import glob
import sys
from pathlib import Path
import shutil
import os
import zipfile

from espnet2.tasks.s2t import S2TTask
from espnet2.text.sentencepiece_tokenizer import SentencepiecesTokenizer
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.s2t.espnet_model import ESPnetS2TModel
from espnet2.bin.s2t_inference import Speech2Text
import espnetez as ez

import torch
import numpy as np
import logging
import gradio as gr
import librosa


def log(temp_dir, text):
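    """Append a line to the session's log file inside the temp directory."""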
    with open(f"{temp_dir}/output.log", "a") as f:
        f.write(text + "\n")


def count_parameters(model):
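    """Count the trainable (requires_grad) parameters of a model."""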
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_dataset(data_path, data_info, test_count=10):
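    """Load the uploaded data directory and split it into train/test datasets.

    Expects `data_path` to contain `text` and `text_ctc` files (one
    `<audio_id> <transcription>` pair per line) and an `audio/` directory
    whose file stems match the audio ids. The first `test_count` utterances
    are held out as the test set.
    """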
    # load data
    data = {}
    keys = []
    with open(f"{data_path}/text", "r", encoding="utf-8") as f:
        for line in f:
            audio_id, text = line.split(maxsplit=1)
            data[audio_id.strip()] = {"text": text.strip()}
            keys.append(audio_id.strip())

    # load text_ctc data
    with open(f"{data_path}/text_ctc", "r", encoding="utf-8") as f:
        for line in f:
            audio_id, text = line.split(maxsplit=1)
            data[audio_id.strip()]["text_ctc"] = text.strip()

    # load audio path
    for audio_path in glob.glob(f"{data_path}/audio/*"):
        audio_id = Path(audio_path).stem
        data[audio_id]["audio_path"] = audio_path
    
    # Convert to a list of examples, preserving the order of the `text` file.
    data = [{
        'id': audio_id,
        'text': data[audio_id]['text'],
        'text_ctc': data[audio_id]['text_ctc'],
        'audio_path': data[audio_id]['audio_path'],
    } for audio_id in keys]

    # Hold out the first `test_count` examples as the test set.
    train_dataset = ez.dataset.ESPnetEZDataset(data[test_count:], data_info)
    test_dataset = ez.dataset.ESPnetEZDataset(data[:test_count], data_info)
    return train_dataset, test_dataset, data[:test_count]


class CustomFinetuneModel(ESPnetS2TModel):
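    """ESPnetS2TModel wrapper that logs averaged loss/accuracy every `log_every` steps."""
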
    def __init__(self, model, tempdir_path, log_every=500):
        super().__init__(
            vocab_size=model.vocab_size,
            token_list=model.token_list,
            frontend=model.frontend,
            specaug=model.specaug,
            normalize=model.normalize,
            preencoder=model.preencoder,
            encoder=model.encoder,
            postencoder=model.postencoder,
            decoder=model.decoder,
            ctc=model.ctc,
            ctc_weight=model.ctc_weight,
            interctc_weight=model.interctc_weight,
            ignore_id=model.ignore_id,
            lsm_weight=0.0,
            length_normalized_loss=False,
            report_cer=False,
            report_wer=False,
            sym_space="<space>",
            sym_blank="<blank>",
            sym_sos = "<sos>",
            sym_eos = "<eos>",
            sym_sop = "<sop>",  # start of prev
            sym_na = "<na>",  # not available
            extract_feats_in_collect_stats=model.extract_feats_in_collect_stats,
        )
        self.iter_count = 0
        self.log_every = log_every
        self.log_stats = {
            'loss': 0.0,
            'acc': 0.0
        }
        self.tempdir_path = tempdir_path
    
    def forward(self, *args, **kwargs):
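        # Run the standard S2T forward pass, then accumulate loss/acc and
        # emit an averaged log line every `log_every` iterations.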
        out = super().forward(*args, **kwargs)
        self.log_stats['loss'] += out[1]['loss'].item()
        self.log_stats['acc'] += out[1]['acc'].item()

        self.iter_count += 1
        if self.iter_count % self.log_every == 0:
            loss = self.log_stats['loss'] / self.log_every
            acc = self.log_stats['acc'] / self.log_every
            log(self.tempdir_path, f"[{self.iter_count}] - loss: {loss:.3f} - acc: {acc:.3f}")
            self.log_stats['loss'] = 0.0
            self.log_stats['acc'] = 0.0

        return out


def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, warmup_steps, optimizer, learning_rate, weight_decay):
    """Main function for finetuning the model."""

    log(tempdir_path, "Start generating baseline...")
    gr.Info("Start generating baseline...")
    ref, base = baseline_model(lang, task, tempdir_path)
    
    log(tempdir_path, "Start generating hypothesis...")
    gr.Info("Start Fine-tuning process...")
    if len(tempdir_path) == 0:
        raise gr.Error("Please upload a zip file first.")

    # define tokenizer
    tokenizer = SentencepiecesTokenizer("assets/owsm_ebf_v3.1_base/bpe.model")
    converter = TokenIDConverter("assets/owsm_ebf_v3.1_base/tokens.txt")

    def tokenize(text):
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

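    # Each entry in data_info maps a model input name to a function that
    # extracts that field from one example. The "text" prompt follows the
    # OWSM format: language tag, task tag, <notimestamps>, then the
    # transcript; "text_prev" is <na> (no previous-utterance context).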
    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        "text_prev": lambda d: tokenize("<na>"),
    }

    # load dataset
    train_dataset, test_dataset, test_list = get_dataset(tempdir_path, data_info)
    log(tempdir_path, "Loaded dataset.")
    gr.Info("Loaded dataset.")

    # load and update configuration
    log(tempdir_path, "Setting up the training configuration...")
    pretrain_config = ez.config.from_yaml(
        "s2t",
        "assets/owsm_ebf_v3.1_base/config.yaml",
    )
    finetune_config = ez.config.update_finetune_config(
        "s2t", pretrain_config, "assets/owsm_ebf_v3.1_base/owsm_finetune_base.yaml"
    )
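
    # Override selected training hyperparameters with the values from the UI.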
    finetune_config['max_epoch'] = max_epoch
    finetune_config['optim'] = optimizer
    finetune_config['optim_conf']['lr'] = learning_rate
    finetune_config['optim_conf']['weight_decay'] = weight_decay
    finetune_config['scheduler'] = scheduler
    finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
    finetune_config['multiple_iterator'] = False
    finetune_config['num_iters_per_epoch'] = None
    finetune_config['multiprocessing_distributed'] = False

    def build_model_fn(args):
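        """Load the pre-trained OWSM model and wrap it for periodic loss/acc logging."""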
        model, _ = S2TTask.build_model_from_file(
            "assets/owsm_ebf_v3.1_base/config.yaml",
            "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        model.train()
        log(tempdir_path, f'Trainable parameters: {count_parameters(model)}')
        model = CustomFinetuneModel(model, tempdir_path, log_every=log_every)
        return model

    trainer = ez.Trainer(
        task='s2t',
        train_config=finetune_config,
        train_dataset=train_dataset,
        valid_dataset=test_dataset,
        build_model_fn=build_model_fn, # provide the pre-trained model
        data_info=data_info,
        output_dir=f"{tempdir_path}/exp/finetune",
        stats_dir=f"{tempdir_path}/exp/stats",
        ngpu=1
    )
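
    # ESPnet requires a stats-collection pass before training to compute
    # input shapes and feature-normalization statistics.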
    gr.Info("start collect stats")
    log(tempdir_path, "Start collect stats process...")
    trainer.collect_stats()

    gr.Info("Finished collect stats, starting training.")
    log(tempdir_path, "Finished collect stats, starting training...")
    trainer.train()
    gr.Info("Finished Fine-tuning!")

    gr.Info("Start generating output for test set!")
    log(tempdir_path, "Start generating output for test set!")

    del trainer
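
    # Rebuild an inference-time Speech2Text wrapper from the pre-trained
    # checkpoint; the fine-tuned weights are swapped in below.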
    model = Speech2Text(
        "assets/owsm_ebf_v3.1_base/config.yaml",
        "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
        device="cuda" if torch.cuda.is_available() else "cpu",
        token_type="bpe",
        bpemodel="assets/owsm_ebf_v3.1_base/bpe.model",
        beam_size=5,
        ctc_weight=0.0,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()
    d = torch.load(f"{tempdir_path}/exp/finetune/valid.acc.ave.pth")
    model.s2t_model.load_state_dict(d)
    
    hyp = ""
    with open(f"{tempdir_path}/hyp.txt", "w") as f_hyp:
        for i in range(len(test_list)):
            data = test_list[i]
            out = model(librosa.load(data['audio_path'], sr=16000)[0])[0][3]
            f_hyp.write(out + '\n')
            hyp += out + '\n'

    log(tempdir_path, "Finished fine-tuning.")
    log(tempdir_path, "Start archiving experiment files...")
    log(tempdir_path, "Create zip file for the following files into `finetune.zip`:")
    log(tempdir_path, "exp/s2t_stats_raw_bpe50000")
    log(tempdir_path, "exp/finetune/tensorboard")
    log(tempdir_path, "exp/finetune/images")
    log(tempdir_path, "exp/finetune/config.yaml")
    log(tempdir_path, "exp/finetune/valid.acc.ave.pth")
    log(tempdir_path, "output.log")
    log(tempdir_path, "acc.png")
    log(tempdir_path, "loss.png")
    log(tempdir_path, "base.txt")
    log(tempdir_path, "hyp.txt")
    log(tempdir_path, "ref.txt")

    def _add_to_zip(zf, path):
        # zipfile.ZipFile.write() stores only a bare entry for a directory,
        # so walk directories and add each file with a tempdir-relative name.
        p = Path(path)
        files = (c for c in sorted(p.rglob("*")) if c.is_file()) if p.is_dir() else [p]
        for f in files:
            zf.write(f, f.relative_to(tempdir_path))

    with zipfile.ZipFile(f"{tempdir_path}/finetune.zip", "w", zipfile.ZIP_DEFLATED) as finetune_zip:
        for name in [
            "exp/stats",
            "exp/finetune/tensorboard",
            "exp/finetune/images",
            "exp/finetune/config.yaml",
            "exp/finetune/valid.acc.ave.pth",
            "output.log",
            "acc.png",
            "loss.png",
            "base.txt",
            "hyp.txt",
            "ref.txt",
        ]:
            _add_to_zip(finetune_zip, f"{tempdir_path}/{name}")

    gr.Info("Finished generating result file in zip!")
    log(tempdir_path, "Finished generating result file in zip!")
    
    return [f"{tempdir_path}/finetune.zip", f"{tempdir_path}/ref.txt", f"{tempdir_path}/base.txt", f"{tempdir_path}/hyp.txt"], ref, base, hyp


def baseline_model(lang, task, tempdir_path):
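    """Generate reference and baseline transcripts for the test set with the pre-trained model."""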
    log(tempdir_path, "Start loading dataset...")
    if len(tempdir_path) == 0:
        log(tempdir_path, "Please upload a zip file first.")
        raise gr.Error("Please upload a zip file first.")

    # define tokenizer
    tokenizer = SentencepiecesTokenizer("assets/owsm_ebf_v3.1_base/bpe.model")
    converter = TokenIDConverter("assets/owsm_ebf_v3.1_base/tokens.txt")

    def tokenize(text):
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        "text_prev": lambda d: tokenize("<na>"),
    }

    # load dataset; only the held-out test split is needed for the baseline
    _, _, test_list = get_dataset(tempdir_path, data_info)
    log(tempdir_path, "Loaded dataset.")
    gr.Info("Loaded dataset.")

    gr.Info("Loading pretrained model...")
    log(tempdir_path, "Loading pretrained model...")

    # Decode with joint CTC/attention scoring (ctc_weight=0.3) and a beam size of 5.
    model = Speech2Text(
        "assets/owsm_ebf_v3.1_base/config.yaml",
        "assets/owsm_ebf_v3.1_base/owsm_v3.1_base.trained.pth",
        device="cuda" if torch.cuda.is_available() else "cpu",
        token_type="bpe",
        bpemodel="assets/owsm_ebf_v3.1_base/bpe.model",
        beam_size=5,
        ctc_weight=0.3,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()
    
    base = ""
    ref = ""
    with open(f"{tempdir_path}/base.txt", "w") as f_base, open(f"{tempdir_path}/ref.txt", "w") as f_ref:
        for i in range(len(test_list)):
            data = test_list[i]
            f_ref.write(data['text'] + '\n')
            out = model(librosa.load(data['audio_path'], sr=16000)[0])[0][3]
            f_base.write(out + '\n')
            ref += data['text'] + '\n'
            base += out + '\n'

    return ref, base