import gradio as gr
import tempfile
from TTS.utils.synthesizer import Synthesizer
from huggingface_hub import hf_hub_download
import torch

CUDA = torch.cuda.is_available()

REPO_ID = "collectivat/catotron-ona"

my_title = "Catotron Text-to-Speech"
my_description = "This model is based on FastSpeech, implemented with 🐸 [Coqui TTS](https://coqui.ai/)."

my_examples = [
  ["Catotron, sintesi de la parla obert i lliure en catalĂ ."],
  ["Leonor Ferrer Girabau va ser una delineant, mestra i activista barcelonina, nascuda al carrer actual de la Concòrdia del Poble-sec, que es va convertir en la primera dona a obtenir el títol de delineant a Catalunya i a l’estat."],
  ["S'espera un dia anticiclònic amb temperatures suaus i vent fluix."]
]

my_inputs = [
  gr.Textbox(lines=5, label="Input Text")
]

my_outputs = gr.Audio(type="filepath", label="Output Audio")

def tts(text: str, split_sentences: bool = True):
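    """Synthesize speech for the input text and return the path to a temporary WAV file."""
    # Fetch the FastSpeech acoustic model and HiFi-GAN vocoder checkpoints from the Hub
    # (hf_hub_download caches files locally, so repeated calls do not re-download them).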
    best_model_path = hf_hub_download(repo_id=REPO_ID, filename="fast-speech_best_model.pth")
    config_path = hf_hub_download(repo_id=REPO_ID, filename="fast-speech_config.json")
    vocoder_model = hf_hub_download(repo_id=REPO_ID, filename="ljspeech--hifigan_v2_model_file.pth")
    vocoder_config = hf_hub_download(repo_id=REPO_ID, filename="ljspeech--hifigan_v2_config.json")
    
    synthesizer = Synthesizer(
        tts_checkpoint=best_model_path,
        tts_config_path=config_path,
        tts_speakers_file=None,
        tts_languages_file=None,
        vocoder_checkpoint=vocoder_model,
        vocoder_config=vocoder_config,
        encoder_checkpoint="",
        encoder_config="",
        use_cuda=CUDA
    )

    # normalize the input: treat newlines as sentence breaks and
    # replace characters the model handles poorly
    text = text.replace("\n", ". ")
    text = text.replace("(", ",")
    text = text.replace(")", ",")
    text = text.replace(";", ",")

    # synthesize the waveform (optionally splitting long text into sentences)
    wavs = synthesizer.tts(text, split_sentences=split_sentences)
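    # write the samples to a temporary WAV file and return its path for Gradio to serve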
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    return fp.name
 
iface = gr.Interface(
    fn=tts,
    inputs=my_inputs,
    outputs=my_outputs,
    title=my_title,
    description=my_description,
    examples=my_examples,
    cache_examples=True
)
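# Serve the demo; with cache_examples=True the example outputs are generated once and reused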
iface.launch()