#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
from datetime import datetime
import functools
import logging
import os
from pathlib import Path
import platform
import time
import tempfile
import hashlib

from project_settings import project_path, log_directory
import log

# Configure project logging once, at import time, writing under `log_directory`.
# NOTE(review): this call sits before the gradio/torch imports below, presumably so
# their loggers pick up this configuration — confirm before moving it.
log.setup(log_directory=log_directory)

import gradio as gr
import torch
import torchaudio

from toolbox.k2_sherpa.examples import examples
from toolbox.k2_sherpa import decode, nn_models
from toolbox.k2_sherpa.utils import audio_convert

# Named logger shared by the request handlers in this module.
# NOTE(review): handlers/levels for "main" are presumably configured by log.setup() — confirm.
main_logger = logging.getLogger("main")


def get_args():
    """Parse the command-line options of the ASR web demo.

    Returns:
        argparse.Namespace with a single string attribute ``pretrained_model_dir``,
        defaulting to ``<project_path>/pretrained_models`` in POSIX form.
    """
    default_model_dir = (project_path / "pretrained_models").as_posix()

    cli = argparse.ArgumentParser()
    cli.add_argument("--pretrained_model_dir", default=default_model_dir, type=str)
    return cli.parse_args()


def update_model_dropdown(language: str):
    """Return an interactive model dropdown for the selected *language*.

    The choices are the ``repo_id`` of every entry registered for *language*
    in ``nn_models.model_map``; the first one is preselected.

    Raises:
        ValueError: if *language* has no entry in ``nn_models.model_map``.
    """
    registry = nn_models.model_map
    if language not in registry:
        raise ValueError(f"Unsupported language: {language}")

    repo_ids = [entry["repo_id"] for entry in registry[language]]
    return gr.Dropdown(choices=repo_ids, value=repo_ids[0], interactive=True)


def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap *s* in the result-box markup styled by the module's CSS.

    Args:
        s: Inner HTML of the box; it is inserted verbatim (not escaped), so callers
            may pass markup such as ``<br/>``.
        style: Extra CSS class for the box, e.g. ``result_item_success`` or
            ``result_item_error``.
    """
    lines = [
        "",
        "    <div class='result'>",
        f"        <div class='result_item {style}'>",
        f"          {s}",
        "        </div>",
        "    </div>",
        "    ",
    ]
    return "\n".join(lines)


def md5_encrypt(text: str) -> str:
    """Return the 32-character lowercase hex MD5 digest of *text* (UTF-8 encoded).

    Despite the name this is a hash, not encryption; it is used here only to
    derive short, stable folder names.
    """
    return hashlib.md5(text.encode()).hexdigest()


def _find_model_config(language: str, repo_id: str):
    """Return the first ``nn_models.model_map[language]`` entry whose ``repo_id`` matches.

    Raises:
        AssertionError: if *language* is unknown or no entry has *repo_id*.
    """
    m_list = nn_models.model_map.get(language)
    if m_list is None:
        raise AssertionError("language invalid: {}".format(language))
    for m in m_list:
        if m["repo_id"] == repo_id:
            return m
    raise AssertionError("repo_id invalid: {}".format(repo_id))


def _repo_id_to_folder(repo_id: str) -> str:
    """Map a repo id (``name`` or ``supplier/name``) to its relative cache folder.

    A repo name longer than 40 characters is replaced by its MD5 hex digest; the
    supplier part, when present, is kept unchanged.

    Raises:
        AssertionError: if *repo_id* does not have exactly one or two path parts.
    """
    parts = Path(repo_id).parts
    if len(parts) not in (1, 2):
        raise AssertionError("repo_id parts count invalid: {}".format(len(parts)))
    *supplier, repo_name = parts
    if len(repo_name) > 40:
        repo_name = md5_encrypt(repo_name)
    return "/".join([*supplier, repo_name])


@torch.no_grad()
def process(
    language: str,
    repo_id: str,
    decoding_method: str,
    num_active_paths: int,
    add_punctuation: str,
    in_filename: str,
    pretrained_model_dir: Path,
):
    """Transcribe one audio file and return ``(text, html_info)``.

    Steps:
      1. Convert *in_filename* with ``audio_convert`` into ``<tmp>/asr/<name>``.
      2. Look up the model config for (*language*, *repo_id*) and load the recognizer
         from ``pretrained_model_dir/huggingface/<folder>``.
      3. Decode. When *add_punctuation* == "Yes", also run the punctuation model.
      4. Report the wave duration, processing time and real-time factor (RTF).

    Raises:
        AssertionError: unknown language, unknown repo id or malformed repo id.
    """
    main_logger.info("language: {}".format(language))
    main_logger.info("repo_id: {}".format(repo_id))
    main_logger.info("decoding_method: {}".format(decoding_method))
    main_logger.info("num_active_paths: {}".format(num_active_paths))
    main_logger.info("in_filename: {}".format(in_filename))

    # audio convert
    in_path = Path(in_filename)
    out_filename = Path(tempfile.gettempdir()) / "asr" / in_path.name
    out_filename.parent.mkdir(parents=True, exist_ok=True)
    audio_convert(in_filename=in_path.as_posix(),
                  out_filename=out_filename.as_posix(),
                  )

    # model settings: validate repo_id against the registry BEFORE building any
    # filesystem path from it (the UI dropdown allows custom values).
    m_dict = _find_model_config(language, repo_id)
    local_model_dir = pretrained_model_dir / "huggingface" / _repo_id_to_folder(repo_id)

    # load recognizer
    recognizer = nn_models.load_recognizer(
        local_model_dir=local_model_dir,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
        **m_dict
    )

    # transcribe
    main_logger.info("Started at {}".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")))
    start = time.time()

    text = decode.decode_by_recognizer(recognizer=recognizer,
                                       filename=out_filename.as_posix(),
                                       )

    # load_punctuation_model
    if add_punctuation == "Yes":
        punctuation_repo_id = "csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12"
        # NOTE(review): this cache folder is md5(full repo id), unlike the
        # supplier/name layout used for recognizers above; kept as-is so existing
        # downloads are still found.
        punct_model_dir = pretrained_model_dir / "huggingface" / md5_encrypt(punctuation_repo_id)
        punctuation_model = nn_models.load_punctuation_model(
            local_model_dir=punct_model_dir,
            repo_id=punctuation_repo_id,
            nn_model_file="model.onnx",
            nn_model_file_sub_folder=".",
        )
        text = punctuation_model.add_punctuation(text)

    # statistics (processing time covers decoding plus optional punctuation)
    end = time.time()
    elapsed = end - start
    # Take a fresh timestamp; the old code reused the start time here.
    finished_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")

    metadata = torchaudio.info(out_filename.as_posix())
    # Read the rate from the converted file instead of assuming 16 kHz.
    duration = metadata.num_frames / metadata.sample_rate
    # Guard against empty audio, which would otherwise raise ZeroDivisionError.
    rtf = elapsed / duration if duration > 0 else float("inf")

    main_logger.info(f"Finished at {finished_at}. Elapsed: {elapsed: .3f} s")

    info = f"""
    Wave duration  : {duration: .3f} s <br/>
    Processing time: {elapsed: .3f} s <br/>
    RTF: {elapsed: .3f}/{duration: .3f} = {rtf:.3f} <br/>
    """

    main_logger.info(info)
    main_logger.info(f"\nrepo_id: {repo_id}\nhyp: {text}")

    return text, build_html_output(info)


def process_uploaded_file(language: str,
                          repo_id: str,
                          decoding_method: str,
                          num_active_paths: int,
                          add_punctuation: str,
                          in_filename: str,
                          pretrained_model_dir: Path,
                          ):
    """Gradio callback: validate the upload, then transcribe it with ``process``.

    Returns:
        ``(text, html)``. If no file was uploaded, or if transcription raises, it
        returns an empty transcript and an error-styled HTML message instead of
        raising. That way the UI always receives something to display.
    """
    if not in_filename:
        return "", build_html_output(
            "Please first upload a file and then click "
            'the button "submit for recognition"',
            "result_item_error",
        )
    main_logger.info(f"Processing uploaded file: {in_filename}")

    try:
        return process(
            in_filename=in_filename,
            language=language,
            repo_id=repo_id,
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
            add_punctuation=add_punctuation,
            pretrained_model_dir=pretrained_model_dir,
        )
    except Exception as e:
        # UI boundary: show the error to the user, but log it at ERROR level with
        # the full traceback so failures remain debuggable.
        msg = "transcribe error: {}".format(str(e))
        main_logger.exception(msg)
        return "", build_html_output(msg, "result_item_error")


# CSS for the result box produced by build_html_output(): `.result` is the outer
# flex container, `.result_item` the message box, and the `result_item_success` /
# `result_item_error` variants set its background colour.
# Style copied from
# https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
"""


def main():
    """Parse CLI args, build the Gradio UI and launch the web server (blocking)."""
    args = get_args()

    pretrained_model_dir = Path(args.pretrained_model_dir)
    # parents=True: a user-supplied directory may have parents that don't exist yet.
    pretrained_model_dir.mkdir(parents=True, exist_ok=True)

    process_uploaded_file_ = functools.partial(
        process_uploaded_file,
        pretrained_model_dir=pretrained_model_dir,
    )

    title = "# Automatic Speech Recognition with Next-gen Kaldi"

    language_choices = list(nn_models.model_map.keys())

    # language -> list of repo ids offered in the model dropdown
    language_to_models = defaultdict(list)
    for language, models in nn_models.model_map.items():
        language_to_models[language].extend(m["repo_id"] for m in models)

    default_language = language_choices[0]
    default_models = language_to_models[default_language]

    # blocks
    with gr.Blocks(css=css) as blocks:
        gr.Markdown(value=title)

        with gr.Tabs():
            with gr.TabItem("Upload from disk"):
                language_radio = gr.Radio(
                    label="Language",
                    choices=language_choices,
                    value=default_language,
                )
                model_dropdown = gr.Dropdown(
                    choices=default_models,
                    label="Select a model",
                    value=default_models[0],
                    allow_custom_value=True
                )
                decoding_method_radio = gr.Radio(
                    label="Decoding method",
                    choices=["greedy_search", "modified_beam_search"],
                    value="greedy_search",
                )
                num_active_paths_slider = gr.Slider(
                    minimum=1,
                    value=4,
                    step=1,
                    label="Number of active paths for modified_beam_search",
                )
                punct_radio = gr.Radio(
                    label="Whether to add punctuation (Only for Chinese and English)",
                    choices=["Yes", "No"],
                    value="Yes",
                )

                uploaded_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload from disk",
                )
                upload_button = gr.Button("Submit for recognition")
                uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
                uploaded_html_info = gr.HTML(label="Info")

                # Same input order as process_uploaded_file's positional parameters.
                recognition_inputs = [
                    language_radio,
                    model_dropdown,
                    decoding_method_radio,
                    num_active_paths_slider,
                    punct_radio,
                    uploaded_file,
                ]
                recognition_outputs = [uploaded_output, uploaded_html_info]

                gr.Examples(
                    examples=examples,
                    inputs=recognition_inputs,
                    outputs=recognition_outputs,
                    fn=process_uploaded_file_,
                )

            upload_button.click(
                process_uploaded_file_,
                inputs=recognition_inputs,
                outputs=recognition_outputs,
            )

        language_radio.change(
            update_model_dropdown,
            inputs=language_radio,
            outputs=model_dropdown,
        )

    # NOTE(review): the old code was `share=False if Windows else False`, which is
    # False everywhere. If public sharing on non-Windows was intended, set it explicitly.
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )


# Script entry point: launch the web UI only when run directly, not when imported.
if __name__ == "__main__":
    main()