import gradio as gr
import subprocess
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
import librosa
import tempfile
from neon_tts_plugin_coqui import CoquiTTS
from gtts import gTTS

#variables
language_input_audio = 'en'
language_output_audio = 'ch'
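# maps the UI language codes above to the FLORES-200 codes expected by the NLLB translation model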
dict_lang = {
    'en': 'eng_Latn',
    'es': 'spa_Latn',
    'fr': 'fra_Latn',
    'de': 'deu_Latn',
    'pl': 'pol_Latn',
    'uk': 'ukr_Cyrl',
    'ro': 'ron_Latn',
    'hu': 'hun_Latn',
    'bg': 'bul_Cyrl',
    'nl': 'nld_Latn',
    'fi': 'fin_Latn',
    'sl': 'slv_Latn',
    'lv': 'lvs_Latn',
    'ga': 'gle_Latn',
    'ch': 'zho_Hant',  # 'ch' is this app's own shorthand for Chinese
    'ru': 'rus_Cyrl'
    }

#functions
# callback that stores the selected input-audio language in its gr.State
def radio_lang_input(lang):
    return lang

# callback that stores the selected output-audio language in its gr.State
def radio_input(lang):
    return lang

##
# convert the input video into translated text, translated audio and a lip-synced video
def video_load(video, language_input_audio, language_output_audio):
    # downscale the video to 720 px width
    subprocess.run(f'ffmpeg -y -i "{video}" -vf scale=720:-2 video720p.mp4', shell=True)
    # extract the audio track as a 16 kHz WAV file for speech recognition
    subprocess.run('ffmpeg -y -i video720p.mp4 -vn -ar 16000 -ac 2 -ab 192K -f wav sound_from_input_video.wav', shell=True)
    #convert audio to text
    #
    # pick a Wav2Vec2 checkpoint for the input language, then load model and processor
    asr_checkpoints = {
        'en': "facebook/wav2vec2-base-960h",
        'ru': "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    }
    processor = Wav2Vec2Processor.from_pretrained(asr_checkpoints[language_input_audio])
    model = Wav2Vec2ForCTC.from_pretrained(asr_checkpoints[language_input_audio])
    audio, rate = librosa.load('sound_from_input_video.wav', sr=16000)
    input_values = processor(audio, sampling_rate=rate, return_tensors="pt", padding="longest").input_values
    # retrieve logits (no gradients needed for inference)
    with torch.no_grad():
        logits = model(input_values).logits
    # take argmax over the vocabulary and decode to text
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    #convert text to text translations
    #
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    # run on GPU when available, otherwise on CPU
    device = 0 if torch.cuda.is_available() else -1
    translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer,
                                    src_lang=dict_lang[language_input_audio],
                                    tgt_lang=dict_lang[language_output_audio],
                                    max_length=2000000, device=device)
    result = translation_pipeline(transcription)
    text_translations = result[0]['translation_text']
    #convert text to audio
    #
    # reuse the shared TTS helper defined below for the selected output language
    audio = text_to_audio_custom(text_translations, language_output_audio)
    #audio to video
    #
    # lip-sync: run Wav2Lip inference to drive the downscaled video with the generated audio
    subprocess.run(f'python inference.py --checkpoint_path wav2lip_gan.pth --face video720p.mp4 --audio "{audio}" --nosmooth --pads 0 20 0 0', shell=True)
    video = 'results/result_voice.mp4'
    return text_translations, audio, video

##
# create a lip-synced video from a custom audio track (expects video720p.mp4 from a previous run)
def audio_to_video_custom(audio):
    subprocess.run(f'python inference.py --checkpoint_path wav2lip_gan.pth --face video720p.mp4 --audio "{audio}" --nosmooth --pads 0 20 0 0', shell=True)
    video = 'results/result_voice.mp4'
    return video

##
# create speech audio from the (possibly edited) translated text
def text_to_audio_custom(text_translations, language_output_audio):
    # Russian: gTTS
    if language_output_audio == 'ru':
        tts = gTTS(text_translations, lang='ru')
        tts.save('ru.mp3')
        audio = 'ru.mp3'
    # European languages supported by the Coqui TTS plugin
    if language_output_audio in ['en', 'es', 'fr', 'de', 'pl', 'uk', 'ro', 'hu', 'bg', 'nl', 'fi', 'sl', 'lv', 'ga']:
        coquiTTS = CoquiTTS()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            coquiTTS.get_tts(text_translations, fp, speaker={"language": language_output_audio})
        audio = fp.name
    # Chinese: gTTS (Mandarin)
    if language_output_audio == 'ch':
        tts = gTTS(text_translations, lang='zh-CN')
        tts.save('china.mp3')
        audio = 'china.mp3'
    return audio

##### blocks
with gr.Blocks(title="Speak video in any language") as demo:
    # state variables for the selected input and output languages
    var = gr.State('en')
    var_lang = gr.State('ch')
    # markdown description
    gr.Markdown("Service for translating videos into other languages with support for the speaker's facial expressions")
    gr.Markdown("The uploaded video should show only a face, preferably without sudden head movements.")
    with gr.Row():
        with gr.Column():
            # radio buttons to select the input video language
            radio_input_lang_video = gr.Radio(['en', 'ru'], value="en", label='Select input video language')
            # video input
            video_input = gr.Video(label="Input Video")
            # radio buttons to select the output language
            radio = gr.Radio(['en', 'es', 'fr', 'de', 'pl', 'uk', 'ro', 'hu', 'bg', 'nl', 'fi', 'sl', 'lv', 'ga', 'ch', 'ru'], value="ch", label='Choose the language you want to speak')
            # main button
            btn_1 = gr.Button("1. Generate video with translated audio")

        with gr.Column():
            # text output
            translations_text = gr.Text(label="Generated Translations Text", interactive=True)
            # button to convert the translation text to speech
            btn_3 = gr.Button("Generate custom translations to speech")
            # output audio
            translations_audio = gr.Audio(label="Generated Translations Audio", interactive=True, type="filepath")
            # button to generate video from the audio
            btn_2 = gr.Button("Generate video with custom audio")
            # video output
            video_output = gr.Video(interactive=False, label="Generated Translations Video")
    # update the input-language state when the selection changes
    radio_input_lang_video.change(fn=radio_lang_input, inputs=radio_input_lang_video, outputs=var)
    # change output lang
    radio.change(fn=radio_input, inputs=radio, outputs=var_lang)
    # main button click
    btn_1.click(video_load, inputs=[video_input, var, var_lang], outputs=[translations_text, translations_audio, video_output])
    # button click: generate video from custom audio
    btn_2.click(audio_to_video_custom, inputs=[translations_audio], outputs=[video_output])
    # button click: generate audio from custom text
    btn_3.click(text_to_audio_custom, inputs=[translations_text, var_lang], outputs=[translations_audio])

demo.launch(show_api=False)