from infer import OnnxInferenceSession
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
import numpy as np
from huggingface_hub import hf_hub_download
import asyncio
from pathlib import Path

# Lazily-created module-wide ONNX session; populated by get_onnx_session().
OnnxSession = None

# Files to fetch from the Hugging Face Hub before inference: each entry
# maps a repo id to the files it provides and the local directory they
# should be downloaded into (consumed by download_models()).
models = [
    {
        "local_path": "./bert/bert-large-cantonese",
        "repo_id": "hon9kon9ize/bert-large-cantonese",
        "files": [
            "pytorch_model.bin"
        ]
    },
    {
        "local_path": "./bert/deberta-v3-large",
        "repo_id": "microsoft/deberta-v3-large",
        "files": [
            "spm.model",
            "pytorch_model.bin"
        ]
    },
    {
        "local_path": "./onnx",
        "repo_id": "hon9kon9ize/bert-vits-zoengjyutgaai-onnx",
        "files": [
            "BertVits2.2PT.json",
            "BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
            "BertVits2.2PT/BertVits2.2PT_emb.onnx",
            "BertVits2.2PT/BertVits2.2PT_dp.onnx",
            "BertVits2.2PT/BertVits2.2PT_sdp.onnx",
            "BertVits2.2PT/BertVits2.2PT_flow.onnx",
            "BertVits2.2PT/BertVits2.2PT_dec.onnx"
        ]
    }
]

def get_onnx_session():
    """Return the shared ONNX inference session, creating it on first use.

    The session is cached in the module-level ``OnnxSession`` global so
    the ONNX graphs are only loaded once per process.
    """
    global OnnxSession

    if OnnxSession is None:
        model_dir = "onnx/BertVits2.2PT"
        components = {
            "enc": "BertVits2.2PT_enc_p.onnx",
            "emb_g": "BertVits2.2PT_emb.onnx",
            "dp": "BertVits2.2PT_dp.onnx",
            "sdp": "BertVits2.2PT_sdp.onnx",
            "flow": "BertVits2.2PT_flow.onnx",
            "dec": "BertVits2.2PT_dec.onnx",
        }
        OnnxSession = OnnxInferenceSession(
            {key: f"{model_dir}/{fname}" for key, fname in components.items()},
            Providers=["CPUExecutionProvider"],
        )
    return OnnxSession

def download_model_files(repo_id, files, local_path):
    """Download any of *files* from the Hub repo *repo_id* that are missing.

    Files already present under *local_path* are left untouched, so the
    function is cheap to call repeatedly.
    """
    base = Path(local_path)
    for filename in files:
        if (base / filename).exists():
            continue
        hf_hub_download(
            repo_id, filename, local_dir=local_path, local_dir_use_symlinks=False
        )

def download_models():
    """Ensure every file listed in the module-level ``models`` spec exists locally."""
    for spec in models:
        download_model_files(spec["repo_id"], spec["files"], spec["local_path"])

def intersperse(lst, item):
    """Return a new list with *item* before, between, and after each element.

    For input of length n the result has length 2n + 1, e.g.
    ``intersperse([a, b], 0) -> [0, a, 0, b, 0]``.
    """
    result = [item]
    for element in lst:
        result.append(element)
        result.append(item)
    return result

def get_text(text, language_str, style_text=None, style_weight=0.7):
    """Convert raw text into the numeric model inputs for inference.

    Args:
        text: The input text to synthesize.
        language_str: Language code, either "EN" or "YUE".
        style_text: Optional reference text for style mixing; "" is
            normalized to None (no style text).
        style_weight: Blend weight passed through to get_bert.

    Returns:
        Tuple (en_bert, yue_bert, phone, tone, language) of numpy arrays;
        the BERT features are transposed so the sequence axis comes first.

    Raises:
        ValueError: If language_str is neither "EN" nor "YUE".
    """
    style_text = None if style_text == "" else style_text
    # Implement the current version of get_text here
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    # add blank (id 0) between every phone/tone/language entry
    phone = intersperse(phone, 0)
    tone = intersperse(tone, 0)
    language = intersperse(language, 0)
    # word2ph holds phones-per-word; double each count to cover the
    # interleaved blanks, and let the first word absorb the leading blank.
    for i in range(len(word2ph)):
        word2ph[i] = word2ph[i] * 2
    word2ph[0] += 1

    bert_ori = get_bert(
        norm_text, word2ph, language_str, "cpu", style_text, style_weight
    )
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    # NOTE(review): the inactive language's BERT slot is filled with random
    # noise rather than zeros — presumably mirroring the upstream
    # Bert-VITS2 inference code; confirm zeros are not expected here.
    if language_str == "EN":
        en_bert = bert_ori
        yue_bert = np.random.randn(1024, len(phone))
    elif language_str == "YUE":
        en_bert = np.random.randn(1024, len(phone))
        yue_bert = bert_ori
    else:
        raise ValueError("language_str should be EN or YUE")

    assert yue_bert.shape[-1] == len(
        phone
    ), f"Bert seq len {yue_bert.shape[-1]} != {len(phone)}"

    phone = np.asarray(phone)
    tone = np.asarray(tone)
    language = np.asarray(language)
    # Transpose BERT features from (1024, seq) to (seq, 1024).
    en_bert = np.asarray(en_bert.T)
    yue_bert = np.asarray(yue_bert.T)

    return en_bert, yue_bert, phone, tone, language

# Text-to-speech function
async def text_to_speech(text, sid=0, language="YUE"):
    """Synthesize *text* with the ONNX Bert-VITS2 model.

    Args:
        text: Text to synthesize; blank input produces no audio.
        sid: Integer speaker id.
        language: Language code, "YUE" or "EN".

    Returns:
        A numpy array of audio samples, or None for blank input.
    """
    if not text.strip():
        # Bug fix: the previous code returned `None, gr.Warning(...)`,
        # handing the caller a 2-tuple that tts_interface then forwarded
        # to gr.Audio as the sample payload. gr.Warning is a side-effect
        # toast in Gradio — call it, then return None audio.
        gr.Warning("Please enter text to convert.")
        return None
    Session = get_onnx_session()
    en_bert, yue_bert, x, tone, language = get_text(text, language)
    sid = np.array([sid])
    audio = Session(x, tone, language, en_bert, yue_bert, sid, sdp_ratio=0.4)

    return audio[0][0]


# Create Gradio application
import gradio as gr

# Gradio interface function
def tts_interface(text):
    """Gradio callback: synthesize *text* and return (sample_rate, samples)."""
    samples = asyncio.run(text_to_speech(text, 0, "YUE"))
    return 44100, samples

async def create_demo():    
    """Build and return the Gradio Interface for the Cantonese TTS demo."""
    # User-facing description (Cantonese): introduces the Bert-VITS2 model
    # trained on the zoengjyutgaai dataset; runtime string left as-is.
    description = """張悦楷粵語語音生成器,基於 Bert-VITS2 模型

本模型由 https://huggingface.co/datasets/laubonghaudoi/zoengjyutgaai_saamgwokjinji 張悦楷語音數據集訓練而得,所以係楷叔把聲。

注意:模型本身支持粵文同英文,但呢個 space 未實現中英夾雜生成。
"""
    
    demo = gr.Interface(
        fn=tts_interface,
        inputs=[
            gr.Textbox(label="Input Text", lines=5),
        ],
        outputs=[
            gr.Audio(label="Generated Audio"),
        ],
        # Sample Cantonese texts of increasing length, shown as one-click examples.
        examples=[
            ["漆黑之中我心眺望,不出一聲但兩眼發光\n寂寞極淒厲,晚風充滿汗,只因她幽怨目光"],
            ["本身我就係一個言出必達嘅人"],
            ["正話坐落喺龍椅上便,突然間,一朕狂風呼——哈噉吹起上嚟。"],
            ["幾日前我喺紅迪出咗個貼,關於學粵語嘅拼音。當時我呻嗰樣嘢係,要大家「為粵語學一套拼音」真係難。先唔好講係邊一個系統,而係我未學識説服大家。於是我就問大家,究竟大家“有幾抗拒”學粵語拼音。239人入面有 121 個話好樂意學粵拼。又有 19 人話,要佢學粵語拼音,佢寧願辭工唔做。其他人就係中間,返工要用就學啦,有錢收就學啦。呢個結果係預咗嘅。因為大家好多時候覺得學粵拼冇用。揾唔到食,揾唔到食,揾唔到食。我咀嚼咗幾日。究竟家長迫仔女學嘅嘢,學校教嘅嘢,有幾多係揾到食嘅呢?例如成日聽到啲人話,細個要學琴,學到八級就可以教琴㗎喇。事實上有幾多個學過琴嘅人,大個係教琴去揾食嘅呢?我都識好多人由細到大,都有學漢語拼音。又係嗰個問題……點解大家唔質疑嘅?乜唔係話幫唔到你揾食,就唔學㗎咩?所以我估……應該一切都係……太難?係咪粵語啲拼音太難?"],
            ["1950年春,廣東開始試行土改,到1951年夏天已在1500萬人口的地區鋪開。廣東省土改委員會主任由華南分局第三書記方方擔任。以林彪為第一書記,鄧子恢為第二書記的中共中央中南局,以及李雪峰為主任的中南局土改委員會, 在對廣東土改的評價上,一直同華南分局之間存在嚴重分歧。李雪峰多次在中南局機關報《長江日報》批評廣東土改群眾發動不夠,太右,是「和平土改」。毛澤東和中南局認為,需要改變廣東土改領導軟弱和進展緩慢的局面。1951年4月,中南局將中共南陽地委書記趙紫陽調到廣東,任華南分局秘書長,5月6日又增選為廣東省土改委員會副主任。1951年12月25日,又將廣西省委代理書記陶鑄調任華南分局第四書記,並接替方方主管廣東土改運動。此後,中南局正式提出了「廣東黨組織嚴重不純,要反對地方主義」的口號。廣東先後36次大規模進行「土改整隊」、「整肅」。到1952年5月,全省共處理廣東「地方主義」幹部6515人。期間,提出了「依靠大軍,依靠南下幹部,由大軍、南下幹部掛帥的方針」。"],
            ["嶺南大學,係廣州度一個經已消失咗嘅大學,原先喺1888年創校,係不隸屬於任何教派嘅基督教大學,係中華民國陣嘅13個基督教大學之一。學科最初係有英文、格致、理化、算術、地理、生物等西學課程,由1927年開始,原有文理學科外,開辦咗農、商、工、醫等學院。經過十年左右發展,成為咗中國大陸南方一個舉足輕重嘅私立大學。史堅如、陳毅、廖承志、冼星海、鄒至莊、曹安邦、陳香梅、姜伯駒等都係呢間學校出來嘅學生,喺全球各地設有同學會、校友會17個。"]
        ],
        title="Cantonese TTS Text-to-Speech 粵語語音合成",
        description=description,
        analytics_enabled=False,
        allow_flagging=False,
    )
    return demo


# Run the application
if __name__ == "__main__":
    # Fetch any missing model weights before the UI comes up.
    download_models()

    app = asyncio.run(create_demo())
    app.launch()