import os
import subprocess
import sys


def install(package):
    """Force a clean reinstall of ``package`` using the running interpreter's pip.

    Any installed copy of the distribution is uninstalled first; the full
    requirement string (version pin included) is then installed.

    Args:
        package: A pip requirement string such as ``"gradio==4.44.0"``,
            ``"spacy>=3.7"`` or a bare distribution name.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits with a
            non-zero status.
    """
    # Local import: keeps this helper self-contained.
    import re

    # ``pip uninstall`` wants the bare distribution name, so cut the spec at the
    # first version operator, extras bracket, marker/URL separator or space.
    # (Splitting on '==' alone failed to unpack for specs such as 'pkg>=1.0'.)
    package_name = re.split(r"[=<>!~;@\[\s]", package.strip(), maxsplit=1)[0]

    try:
        subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", package_name])
        print(f"Successfully uninstalled {package}")
    except subprocess.CalledProcessError:
        # Best effort: a failed uninstall must not block the install below.
        print(f"Package {package} was not installed, proceeding with installation")
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# install('pydantic==2.0.0')
# install('gradio==4.44.0')
# install('spacy==3.7')

# Runtime flags.
# is_prod: True unless the PROD_MODE environment variable is exactly 'local'.
# debug:   always False here; a true value would skip the model imports and
#          the model setup guarded by `if not debug` further down.
is_prod = os.environ.get('PROD_MODE') != 'local'
debug = False

import pickle

import gradio as gr
import os

if not is_prod:
    # Local (non-production) run: point every Hugging Face cache variable at
    # one shared directory and append a custom ffmpeg build to PATH.
    import os

    os.environ.update(
        dict.fromkeys(
            (
                'HF_HOME',
                'TRANSFORMERS_CACHE',
                'HF_DATASETS_CACHE',
                'HF_METRICS_CACHE',
                'HF_MODULES_CACHE',
            ),
            '/proj/afosr/metavoice/cache',
        )
    )

    ffmpeg_path = '/home/hc3295/ffmpegg_build/bin'
    os.environ['PATH'] = os.environ['PATH'] + os.pathsep + ffmpeg_path


import torch
if not debug:
    import shutil
    import tempfile
    import time
    from pathlib import Path

    import librosa
    
    from huggingface_hub import snapshot_download

    from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
    from fam.llm.decoders import EncodecDecoder
    from fam.llm.fast_inference_utils import build_model, main
    from fam.llm.inference import (
        EncodecDecoder,
        InferenceConfig,
        Model,
        TiltedEncodec,
        TrainedBPETokeniser,
        get_cached_embedding,
        get_cached_file,    
        get_enhancer,
    )
    from fam.llm.utils import (
        check_audio_file,
        get_default_dtype,
        get_device,
        normalize_text,
    )



# Notice text for hosts without CUDA; stays empty when a GPU is available.
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
elif not debug:
    # --- static configuration -------------------------------------------
    model_name = "metavoiceio/metavoice-1B-v0.1"
    seed = 1337
    output_dir = "outputs"
    _device = 'cuda:0'
    _dtype = get_default_dtype()

    # --- checkpoint download and first-stage token adapter --------------
    _model_dir = snapshot_download(repo_id=model_name)
    first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
    os.makedirs(output_dir, exist_ok=True)

    # --- second stage model, loaded from the downloaded snapshot --------
    second_stage_ckpt_path = f"{_model_dir}/second_stage.pt"
    config_second_stage = InferenceConfig(
        ckpt_path=second_stage_ckpt_path,
        num_samples=1,
        seed=seed,
        device=_device,
        dtype=_dtype,
        compile=False,
        init_from="resume",
        output_dir=output_dir,
    )
    data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
    llm_second_stage = Model(
        config_second_stage,
        TrainedBPETokeniser,
        EncodecDecoder,
        data_adapter_fn=data_adapter_second_stage.decode,
    )
    enhancer = get_enhancer("df")

    # --- first stage model and speaker encoder --------------------------
    # Map the dtype name onto the torch dtype; only these two names are
    # accepted, anything else raises KeyError.
    precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[_dtype]
    model, tokenizer, smodel, model_size = build_model(
        precision=precision,
        checkpoint_path=Path(f"{_model_dir}/first_stage.pt"),
        spk_emb_ckpt_path=Path(f"{_model_dir}/speaker_encoder.pt"),
        device=_device,
        compile=True,
        compile_prefill=True,
    )

def _embed_reference(path):
    """Resolve ``path`` via ``get_cached_file``, validate it, and return its speaker embedding.

    The embedding comes from ``get_cached_embedding`` with the module-level
    speaker model and is moved to the model device and precision.
    """
    local_path = get_cached_file(path)
    check_audio_file(local_path)
    return get_cached_embedding(local_path, smodel).to(device=_device, dtype=precision)


def generate_sample(text, emo_dir = None, source_path = None, emo_path = None, neutral_path = None, strength = 0.1, top_p = 0.95, guidance_scale = 3.0, preset_dropdown = None, toggle = None):
    """Synthesise ``text`` in a reference voice shifted along an emotion direction.

    The speaker embedding of the reference voice is offset by
    ``strength * emotion_direction`` and the edited embedding conditions both
    model stages. The result is passed through the enhancer and written back
    to the generated ``.wav`` file.

    Args:
        text: Text to speak. Over-long text is truncated to ``MAX_CHARS``
            characters (a warning is shown), as the UI label advertises.
        emo_dir: Either ``EMO_NAMES[0]`` (extract the direction from
            ``emo_path``/``neutral_path``) or a key of ``ALL_EMO_DIRS``.
        source_path: Reference voice audio; ignored when ``toggle`` selects
            the preset voices.
        emo_path: Emotional sample, used only for extraction.
        neutral_path: Neutral sample, used only for extraction.
        strength: Multiplier applied to the emotion direction.
        top_p: Forwarded unchanged as the first stage's ``top_p``.
        guidance_scale: Forwarded unchanged as the first stage's
            ``guidance_scale``.
        preset_dropdown: Key into ``PRESET_VOICES`` when a preset is used.
        toggle: Voice-source choice; ``RADIO_CHOICES[0]`` means preset voice.

    Returns:
        Path of the generated (enhanced) ``.wav`` file as a string.

    Raises:
        gr.Error: If the text, the reference voice, the emotion choice or
            the extraction samples are missing or unusable.
    """
    for label, value in (
        ('text', text),
        ('emo_dir', emo_dir),
        ('source_path', source_path),
        ('emo_path', emo_path),
        ('neutral_path', neutral_path),
        ('strength', strength),
        ('top_p', top_p),
        ('guidance_scale', guidance_scale),
    ):
        print(label, value)

    use_preset = toggle == RADIO_CHOICES[0]

    # Validate up front so the user gets a readable message instead of a
    # failure deep inside the model code. An uploaded file left over from a
    # previous selection is ignored while a preset voice is active.
    _handle_edge_cases(text, None if use_preset else source_path)
    text = text[:MAX_CHARS]

    if use_preset:
        source_path = PRESET_VOICES.get(preset_dropdown)
    if not source_path:
        raise gr.Error("Please choose a preset voice or upload a voice sample to clone")
    source_emb = _embed_reference(source_path)

    if emo_dir == EMO_NAMES[0]:
        if not emo_path or not neutral_path:
            raise gr.Error("Please upload both an emotional and a neutral sample for emotion extraction")
        # Emotion direction = emotional embedding minus neutral embedding,
        # scaled to unit L2 norm.
        emo_dir = _embed_reference(emo_path) - _embed_reference(neutral_path)
        direction_norm = torch.norm(emo_dir, p=2)
        if direction_norm.item() == 0:
            # Identical embeddings would divide by zero and yield NaNs.
            raise gr.Error("The emotional and neutral samples are indistinguishable; please upload two different recordings")
        emo_dir = emo_dir / direction_norm
    else:
        if emo_dir not in ALL_EMO_DIRS:
            raise gr.Error("Please choose an emotion")
        emo_dir = torch.tensor(ALL_EMO_DIRS[emo_dir], device=_device, dtype=precision)

    edited_emb = (source_emb + strength * emo_dir).to(device=_device, dtype=precision)

    temperature = 1.0
    text = normalize_text(text)

    # first stage LLM
    tokens = main(
        model=model,
        tokenizer=tokenizer,
        model_size=model_size,
        prompt=text,
        spk_emb=edited_emb,
        top_p=torch.tensor(top_p, device=_device, dtype=precision),
        guidance_scale=torch.tensor(guidance_scale, device=_device, dtype=precision),
        temperature=torch.tensor(temperature, device=_device, dtype=precision),
    )
    # Only the audio ids are needed; the decoded text ids are discarded.
    _, extracted_audio_ids = first_stage_adapter.decode([tokens])

    b_speaker_embs = edited_emb.unsqueeze(0)

    # second stage LLM + multi-band diffusion model
    wav_files = llm_second_stage(
        texts=[text],
        encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=_device).unsqueeze(0)],
        speaker_embs=b_speaker_embs,
        batch_size=1,
        guidance_scale=None,
        top_p=None,
        top_k=200,
        temperature=1.0,
        max_new_tokens=None,
    )

    wav_file = wav_files[0]
    output_path = str(wav_file) + ".wav"
    # Enhance into a temporary file, then overwrite the raw output with it.
    with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
        enhancer(output_path, enhanced_tmp.name)
        shutil.copy2(enhanced_tmp.name, output_path)
        print(f"\nSaved audio to {wav_file}.wav")

    return output_path


def _load_emo_dirs(path='all_emo_dirs.pkl'):
    """Load the preset emotion directions from the pickle file at ``path``.

    The file handle is closed as soon as the data has been read.
    """
    # NOTE: pickle can execute arbitrary code while loading; this file must
    # only ever come from a trusted source (it ships with the app).
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Preset emotion directions, keyed by emotion name.
ALL_EMO_DIRS = _load_emo_dirs()
# The first entry is the "extract from my own samples" option; the remaining
# entries are the preset emotion names.
EMO_NAMES = ['Upload your own sample'] + list(ALL_EMO_DIRS.keys())

# Voice-source options; handlers compare the selection against index 0.
RADIO_CHOICES = ["Preset voices", "Upload your voice"]

# Maximum prompt length, in characters.
MAX_CHARS = 220

# Display name -> URL of the reference clip for each preset voice.
PRESET_VOICES = dict(
    # female
    Bria="https://cdn.themetavoice.xyz/speakers%2Fbria.mp3",
    # male
    Alex="https://cdn.themetavoice.xyz/speakers/alex.mp3",
    Jacob="https://cdn.themetavoice.xyz/speakers/jacob.wav",
)


def denormalise_top_p(top_p):
    """Map a slider value onto a top_p value.

    A slider value of 0 gives 0.9 and 10 gives 1.0; the result is rounded to
    two decimal places.
    """
    shifted = 0.9 + top_p / 100
    return round(shifted, 2)


def denormalise_guidance(guidance):
    """Linearly rescale a slider value from [1, 5] onto guidance in [1.0, 3.0]."""
    slider_min, slider_max = 1, 5
    scale_min, scale_max = 1, 3
    offset = (guidance - slider_min) * (scale_max - scale_min)
    return scale_min + offset / (slider_max - slider_min)


def _check_file_size(path):
    if not path:
        return
    filesize = os.path.getsize(path)
    filesize_mb = filesize / 1024 / 1024
    if filesize_mb >= 50:
        raise gr.Error(f"Please upload a sample less than 20MB for voice cloning. Provided: {round(filesize_mb)} MB")


def _handle_edge_cases(to_say, upload_target):
    """Validate the prompt text and, when one is given, the uploaded audio.

    Raises ``gr.Error`` for an empty prompt and shows a ``gr.Warning`` when the
    prompt is longer than ``MAX_CHARS``. The upload is checked only when
    ``upload_target`` is truthy.
    """
    if not to_say:
        raise gr.Error("Please provide text to synthesise")

    too_long = len(to_say) > MAX_CHARS
    if too_long:
        gr.Warning(
            f"Max {MAX_CHARS} characters allowed. Provided: {len(to_say)} characters. Truncating and generating speech...Result at the end can be unstable as a result."
        )

    if upload_target:
        # Per the original author's note this enforces a minimum duration
        # (at least 30s); the helper itself is defined outside this file.
        check_audio_file(upload_target)
        _check_file_size(upload_target)


def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target):
    """Synthesise ``to_say`` with ``TTS_MODEL`` using denormalised slider values.

    NOTE(review): ``TTS_MODEL`` is not defined in the visible part of this file
    and this handler is not wired to any UI event here; confirm whether it is
    still needed.

    Args:
        to_say: Prompt text; truncated to ``MAX_CHARS`` characters.
        top_p: Slider value, converted with ``denormalise_top_p``.
        guidance: Slider value, converted with ``denormalise_guidance``.
        toggle: Voice-source choice; ``RADIO_CHOICES[0]`` selects a preset.
        preset_dropdown: Key into ``PRESET_VOICES`` when a preset is used.
        upload_target: Path of the uploaded reference audio otherwise.

    Returns:
        Whatever ``TTS_MODEL.synthesise`` returns.

    Raises:
        gr.Error: For invalid input (message passed through unchanged) or for
            any other failure (wrapped, with the original exception chained).
    """
    try:
        d_top_p = denormalise_top_p(top_p)
        d_guidance = denormalise_guidance(guidance)

        _handle_edge_cases(to_say, upload_target)

        # Slicing is a no-op for text that is already short enough.
        to_say = to_say[:MAX_CHARS]

        if toggle == RADIO_CHOICES[0]:
            spk_ref_path = PRESET_VOICES[preset_dropdown]
        else:
            spk_ref_path = upload_target

        return TTS_MODEL.synthesise(
            text=to_say,
            spk_ref_path=spk_ref_path,
            top_p=d_top_p,
            guidance_scale=d_guidance,
        )
    except gr.Error:
        # Already a user-facing message; do not bury it in a generic wrapper.
        raise
    except Exception as e:
        raise gr.Error(f"Something went wrong. Reason: {str(e)}") from e


def change_voice_selection_layout(choice):
    """Return visibility updates for the preset row and the upload row.

    The first update is visible when ``choice`` is ``RADIO_CHOICES[0]``; the
    second is visible otherwise.
    """
    preset_selected = choice == RADIO_CHOICES[0]
    return [gr.update(visible=preset_selected), gr.update(visible=not preset_selected)]

def change_emotion_selection_layout(choice):
    """Return three visibility updates, all visible only for ``EMO_NAMES[0]``."""
    show_uploads = choice == EMO_NAMES[0]
    return [gr.update(visible=show_uploads) for _ in range(3)]

# HTML header rendered at the top of the page via gr.Markdown: the Google Tag
# Manager snippets plus the page heading. A stray closing style tag that had
# no matching opening tag has been removed.
title = """
<!-- Google Tag Manager -->
<script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
})(window,document,'script','dataLayer','GTM-5N27BQH8');</script>
<!-- End Google Tag Manager -->

<h1 style="margin-top: 10px;" class="page-title">Demo for <span style="margin-left: 10px;background-color: #E0FEE4;padding: 15px;border-radius: 10px;">🎛️ EmoKnob</span></h1>

<!-- Google Tag Manager (noscript) -->
<noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5N27BQH8"
height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript>
<!-- End Google Tag Manager (noscript) -->

"""

# Markdown introduction shown under the page heading via gr.Markdown.
description = """

- EmoKnob applies control of emotion over an arbitrary speaker.
- EmoKnob <b>extracts emotion from a pair of emotional and neutral audio from the same speaker.</b>
- In this demo, you can select from a few preset voices and upload your own emotional samples to clone.
- You can then apply control of a preset emotion or extract emotion from your own pair of emotional and neutral audio.
- You can adjust the strength of the emotion by using the slider.

Check out our [project page](https://emoknob.cs.columbia.edu/) for more details.

EmoKnob uses [MetaVoice](https://github.com/metavoiceio/metavoice-src) as its voice cloning backbone.
"""

# Build the Gradio UI: inputs in the left column, generated audio on the right.
# (The page title previously repeated the "EmoKnob:" prefix.)
with gr.Blocks(title="EmoKnob: Enhance Voice Cloning with Fine-Grained Emotion Control") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Image("https://raw.githubusercontent.com/tonychenxyz/emoknob/main/docs/assets/emo-knob-teaser-1.svg", show_label=False, container=False)

    with gr.Row():
        with gr.Column():
            to_say = gr.TextArea(
                label=f"What should I say!? (max {MAX_CHARS} characters).",
                lines=4,
                value="To be or not to be, that is the question.",
            )

            # Voice selection: either a preset voice (row_1) or an uploaded
            # sample (row_2); `toggle` switches which row is visible.
            with gr.Row(), gr.Column():
                toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0])

                with gr.Row() as row_1:
                    preset_dropdown = gr.Dropdown(
                        list(PRESET_VOICES), label="Preset voices", value=list(PRESET_VOICES.keys())[0]
                    )

                    with gr.Accordion("Preview: Preset voices", open=False):
                        for label, path in PRESET_VOICES.items():
                            gr.Audio(value=path, label=label)

                with gr.Row(visible=False) as row_2:
                    upload_target = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload a clean sample to clone.",
                    )

            # Emotion selection: strength, a preset emotion or "upload your
            # own", and the sample pair (row_3) used for extraction.
            with gr.Row(), gr.Column():
                strength = gr.Slider(
                    value=0.3,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    label="Strength - how strong the emotion is. Recommended value is between 0.0 and 0.6.",
                )

                with gr.Row():
                    # Default to the first preset emotion (index 0 is the upload option).
                    emotion_name = gr.Radio(choices=EMO_NAMES, label="Emotion", value=EMO_NAMES[1])

                with gr.Row(visible=False) as row_3:
                    upload_neutral = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Neutral sample for emotion extraction.",
                    )

                    upload_emo = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Emotional sample for emotion extraction.",
                    )

            with gr.Row(), gr.Column():
                # Sampling settings. Both values are handed to generate_sample
                # as-is, so the sliders must cover raw parameter values.
                # top_p previously ran 0-10 in steps of 1 with an off-grid
                # default of 0.95; it now covers 0.9-1.0, the range documented
                # on denormalise_top_p.
                top_p = gr.Slider(
                    value=0.95,
                    minimum=0.9,
                    maximum=1.0,
                    step=0.01,
                    label="Speech Stability - improves text following for a challenging speaker",
                )
                # NOTE(review): guidance is also forwarded unscaled; confirm the
                # 1-5 range is intended (denormalise_guidance documents 1.0-3.0).
                guidance = gr.Slider(
                    value=3.0,
                    minimum=1.0,
                    maximum=5.0,
                    step=1.0,
                    label="Speaker similarity - How closely to match speaker identity and speech style.",
                )

            # Show the sample-pair uploads only for the "upload your own" emotion.
            emotion_name.change(
                change_emotion_selection_layout,
                inputs=emotion_name,
                outputs=[row_3, upload_neutral, upload_emo],
            )

            # Swap between the preset row and the upload row.
            toggle.change(
                change_voice_selection_layout,
                inputs=toggle,
                outputs=[row_1, row_2],
            )

        with gr.Column():
            speech = gr.Audio(
                type="filepath",
                label="Model says...",
            )

    submit = gr.Button("Generate Speech")
    # The input order must match generate_sample's positional parameters.
    submit.click(
        fn=generate_sample,
        inputs=[to_say, emotion_name, upload_target, upload_emo, upload_neutral, strength, top_p, guidance, preset_dropdown, toggle],
        outputs=speech,
    )


demo.launch()