import re
import os
import nltk
import torch
import pickle
import torchaudio
import numpy as np
from TTS.tts.models.xtts import Xtts
from nltk.tokenize import sent_tokenize
from TTS.tts.configs.xtts_config import XttsConfig


def _load_array(filename):
    """ Opens a pickled file and returns its contents; used for the speaker
    latent and embedding files (pickled despite the .npy extension) """
    with open(filename, 'rb') as f:
        return pickle.load(f)


# Accept the Coqui model license non-interactively
os.environ['COQUI_TOS_AGREED'] = '1'

# Audio is generated by cloning a voice from a pre-computed speaker sample
nltk.download('punkt')  # tokenizer data for nltk.sent_tokenize
model_path = "tts_model"

config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))

model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    use_deepspeed=True,  # requires the deepspeed package; set to False if it is not installed
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Speaker latent
path_latents = 'assets/gpt_cond_latent.npy'
gpt_cond_latent = _load_array(path_latents)

# Speaker embedding
path_embedding = 'assets/speaker_embedding.npy'
speaker_embedding = _load_array(path_embedding)
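
# If these files are missing, the latent and embedding can be computed from a
# short reference recording of the target voice and pickled for reuse. A
# minimal sketch, assuming a local file 'assets/reference.wav' (hypothetical
# path) and the standard Coqui XTTS API:
#
# gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
#     audio_path=['assets/reference.wav']
# )
# for path, tensor in [(path_latents, gpt_cond_latent),
#                      (path_embedding, speaker_embedding)]:
#     with open(path, 'wb') as f:
#         pickle.dump(tensor, f)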


def get_audio(text: str, language: str = 'es', saving_path: str = 'output') -> None:
    """
    Generates an audio file from the given text.
    :param text: text to convert to audio
    :param language: 'es', 'en' or 'pt', language used for the audio file
    :param saving_path: path of the output file, without the .wav extension
    :return: None
    """
    # Synthesize the answer and save it as {saving_path}.wav
    _save_audio(text, language, saving_path)


def _save_audio(text: str, language: str, path_audio: str) -> None:
    """
    Splits the text into sentences, cleans them, synthesizes audio for each one,
    then concatenates the segments and saves them to a single file.
    :param text: input text
    :param language: language used in the audio
    :param path_audio: saving path of the audio, without the .wav extension
    :return: None
    """
    # Split the answer into sentences and clean it
    sentences = _get_clean_text(text, language)

    # Get the voice of each sentence
    audio_segments = []
    for sentence in sentences:
        audio_stream = _get_voice(sentence, language)
        audio_stream = torch.tensor(audio_stream)
        audio_segments.append(audio_stream)

    # Concatenate all segments and save them as a single mono file (XTTS outputs 24 kHz audio)
    concatenated_audio = torch.cat(audio_segments, dim=0)
    torchaudio.save(f'{path_audio}.wav', concatenated_audio.unsqueeze(0), 24000)


def _get_voice(sentence: str, language: str) -> np.ndarray:
    """
    Synthesizes the given sentence in the given language and returns the
    waveform as a numpy array.
    :param sentence: input sentence
    :param language: language used in the audio
    :return: numpy array with the audio
    """
    out = model.inference(
        sentence,
        language=language,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=0.1  # low temperature keeps the delivery stable across sentences
    )
    return out['wav']
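
# For lower latency, recent Coqui releases also expose a streaming variant that
# yields audio chunks as they are generated. A minimal sketch, assuming the
# installed TTS version provides Xtts.inference_stream:
#
# for chunk in model.inference_stream(
#     sentence,
#     language=language,
#     gpt_cond_latent=gpt_cond_latent,
#     speaker_embedding=speaker_embedding,
# ):
#     ...  # each chunk is a torch tensor sampled at 24 kHz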


def _get_clean_text(text: str, language: str) -> list[str]:
    """
    Splits the text into smaller sentences using nltk and removes links.
    :param text: input text for the audio
    :param language: language used for the audio ('es', 'en', 'pt')
    :return: list of sentences
    """
    # Replace links with a spoken phrase and set the per-language character limit
    if language == 'en':
        clean_answer = re.sub(r'http[s]?://\S+', 'the following link', text)
        max_characters = 250
    elif language == 'es':
        clean_answer = re.sub(r'http[s]?://\S+', 'el siguiente link', text)
        max_characters = 239
    else:
        clean_answer = re.sub(r'http[s]?://\S+', 'o seguinte link', text)
        max_characters = 203

    # Change the name from Bella to Bela
    clean_answer = clean_answer.replace('Bella', 'Bela')

    # Remove the Florida state abbreviation and zip code
    clean_answer = re.sub(r', FL \d+', "", clean_answer)

    # Split the answer into sentences with nltk and make sure none exceeds the character limit
    split_sentences = sent_tokenize(clean_answer)
    sentences = []
    for sentence in split_sentences:
        if len(sentence) > max_characters:
            sentences.extend(_split_sentence(sentence, max_characters))
        else:
            sentences.append(sentence)

    return sentences
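
# A worked example of the cleaning above (hypothetical input): the link is
# replaced by a spoken phrase, 'Bella' becomes 'Bela', and ', FL <zip>' is
# dropped before sentence splitting:
#
# _get_clean_text('Visit https://example.com for details. '
#                 'Bella lives in Miami, FL 33101.', 'en')
# -> ['Visit the following link for details.', 'Bela lives in Miami.']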


def _split_sentence(sentence: str, max_characters: int) -> list[str]:
    """
    Used when a sentence is still too long. The split point is the comma nearest
    to the middle of the sentence; if there is no comma, the nearest space is
    used, and failing both the sentence is cut at the middle. Parts that are
    still too long are split again recursively.
    :param sentence: sentence to be split
    :param max_characters: max number of characters a sentence can have
    :return: list of sentences
    """
    # Get index of each comma
    sentences = []
    commas = [i for i, c in enumerate(sentence) if c == ',']

    # No commas, search for spaces
    if len(commas) == 0:
        commas = [i for i, c in enumerate(sentence) if c == ' ']

    # No commas or spaces: cut at the middle and recurse in case a half is still too long
    if len(commas) == 0:
        middle = len(sentence) // 2
        for half in (sentence[:middle], sentence[middle:]):
            if len(half) > max_characters:
                sentences.extend(_split_sentence(half, max_characters))
            else:
                sentences.append(half)
        return sentences

    # Pick the split index nearest to the middle of the sentence
    split_point = min(commas, key=lambda x: abs(x - (len(sentence) // 2)))

    # Drop the separator itself and any space that follows it
    left = sentence[:split_point]
    right = sentence[split_point + 1:].lstrip()

    if len(left) > max_characters:
        sentences.extend(_split_sentence(left, max_characters))
    else:
        sentences.append(left)
    if len(right) > max_characters:
        sentences.extend(_split_sentence(right, max_characters))
    else:
        sentences.append(right)

    return sentences
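

if __name__ == '__main__':
    # Minimal usage example (text and output name are illustrative):
    # synthesizes a short Spanish sentence and writes greeting.wav
    get_audio('Hola, bienvenido.', language='es', saving_path='greeting')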