import functools
import html
import random
import re
import secrets

import gradio as gr
import torch
import whisper
from nemo.collections.asr.models import EncDecSpeakerLabelModel

# from transformers import Wav2Vec2Processor, Wav2Vec2Tokenizer


# Run the speaker-embedding model on the GPU when one is visible to torch, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@functools.lru_cache(maxsize=4)
def _load_whisper_model(model_name):
    """Load a Whisper checkpoint once and reuse it on every later call."""
    return whisper.load_model(model_name)


def audio_to_text(audio, model_name="base.en"):
    """Transcribe an audio file with Whisper and return the recognised text.

    The Whisper model is loaded lazily on first use and cached, instead of
    being re-read from disk on every call (``compare_samples`` calls this
    twice per request).

    Args:
        audio: Path to the audio file to transcribe (the microphone inputs
            in this app are configured with ``type="filepath"``).
        model_name: Whisper checkpoint name passed to ``whisper.load_model``;
            defaults to the English-only ``"base.en"`` used originally.

    Returns:
        The ``"text"`` field of Whisper's transcription result.
    """
    model = _load_whisper_model(model_name)
    samples = whisper.load_audio(audio)
    result = model.transcribe(samples)
    return result["text"]


# Pool of three-word challenge phrases the user is asked to speak aloud.
random_sentences = [
    "the keep brown", "jump over table", "green mango fruit",
    "how much money", "please audio speaker", "nothing is better",
    "garden banana orange", "tiger animal king", "laptop mouse monitor",
]

# Alternative pool of challenge phrases (not referenced elsewhere in this file).
# Every entry needs its own comma: without them Python implicitly
# concatenates adjacent string literals into one single long string.
additional_random_sentences = [
    "sunrise over mountains",
    "whispering gentle breeze",
    "garden of roses",
    "melodies in rain",
    "laughing with friends",
    "silent midnight moon",
    "skipping in meadow",
    "ocean waves crashing",
    "exploring hidden caves",
    "serenading under stars",
]


# Pick the challenge phrase (the spoken "OTP") that pre-fills the text prompt below.
def get_random_sentence(sentences=None):
    """Return a randomly chosen challenge phrase for the user to speak.

    Args:
        sentences: Optional non-empty sequence to choose from. Defaults to
            the module-level ``random_sentences`` pool, so calling with no
            arguments (as the Gradio text box default does) is unchanged.

    Returns:
        One element of the chosen pool.

    Raises:
        IndexError: if the pool is empty.
    """
    pool = random_sentences if sentences is None else sentences
    # The phrase acts as an anti-replay challenge, so draw it from the OS
    # CSPRNG rather than the predictable Mersenne Twister behind ``random``.
    return secrets.choice(pool)


# Text prompt holding the challenge phrase. The default is the callable
# itself (not a fixed string) -- presumably so Gradio draws a new phrase
# per page load; confirm against the Gradio version in use.
text_inputs = [
    gr.inputs.Textbox(
        lines=1,
        default=get_random_sentence,
        label="Speak the Words given below:",
    ),
]

# Bootstrap 5.1.3 stylesheet link prepended to every HTML result page.
STYLE = (
    "\n"
    '<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" '
    'integrity="sha256-YvdLHPgkqJ8DVUxjjnGVlMMJtNimJ6dYkowFFvp4kKs=" crossorigin="anonymous">'
    "\n"
)

# HTML shown when the transcripts do not match the challenge phrase.
# Fill it with ``.format(spoken1, spoken2)``: the two ``{}`` placeholders
# receive the cleaned transcript of recording 1 and recording 2.
OUTPUT_ERROR = (
        STYLE
        + """
    <div class="container">
        <div class="row"><h1 style="text-align: center">Spoken Words Did Not Match the OTP.</h1></div>
        <div class="row"><h1 class="text-danger" style="text-align: center">Please Speak Clearly!!!!</h1></div>
        <div class="row"><h1 class="display-1 text-success" style="text-align: center">Words Spoken 1: {}</h1></div>
        <div class="row"><h1 class="display-1 text-success" style="text-align: center">Words Spoken 2: {}</h1></div>
    </div>
"""
)

# HTML shown when both transcripts match the challenge phrase and the
# speaker-similarity score reaches THRESHOLD. Contains no format placeholders.
OUTPUT_OK = (
        STYLE
        + """
    <div class="container">
        <div class="row"><h1 style="text-align: center">The provided samples are</h1></div>
        <div class="row"><h1 class="text-success" style="text-align: center">Same Speakers!!!</h1></div>
        <div class="row"><h1 class="text-success" style="text-align: center">Authentication Successful!!!</h1></div>

    </div>
"""
)
# HTML shown when both transcripts match the challenge phrase but the
# speaker-similarity score is below THRESHOLD. Contains no format placeholders.
OUTPUT_FAIL = (
        STYLE
        + """
    <div class="container">
        <div class="row"><h1 style="text-align: center">The provided samples are from </h1></div>
        <div class="row"><h1 class="text-danger" style="text-align: center">Different Speakers!!!</h1></div> 
        <div class="row"><h1 class="text-danger" style="text-align: center">Authentication Failed!!!</h1></div>       
    </div>
"""
)

# Minimum rescaled similarity score ((cosine + 1) / 2, computed in
# compare_samples) for two recordings to count as the same speaker.
THRESHOLD = 0.80

# Pretrained NeMo speaker-embedding checkpoint, loaded once when the module
# is imported and moved onto the device selected above.
model_name = "nvidia/speakerverification_en_titanet_large"
model = EncDecSpeakerLabelModel.from_pretrained(model_name)
model = model.to(device)


# Anything that is neither a word character nor whitespace counts as
# punctuation for comparison (",.?!" as before, plus "-", ";", ":", quotes...).
_PUNCTUATION_RE = re.compile(r"[^\w\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def clean_sentence(sentence):
    """Normalise a sentence so a transcript can be compared with the OTP.

    Every punctuation character is replaced by a space (so "green-mango"
    matches "green mango"), runs of whitespace are collapsed to one space,
    the result is stripped and lower-cased.

    Args:
        sentence: Text to normalise; ``None`` or ``""`` yields ``""``.

    Returns:
        The normalised string.
    """
    if not sentence:
        return ""
    without_punctuation = _PUNCTUATION_RE.sub(" ", sentence)
    return _WHITESPACE_RE.sub(" ", without_punctuation).strip().lower()


def _speaker_similarity(path1, path2):
    """Return the similarity of the two recordings' speaker embeddings in [0, 1].

    Embeddings come from the module-level NeMo ``model``; their cosine
    similarity (range [-1, 1]) is rescaled with ``(cos + 1) / 2`` so it can
    be compared directly with ``THRESHOLD``.
    """
    embs1 = model.get_embedding(path1).squeeze()
    embs2 = model.get_embedding(path2).squeeze()
    cosine = torch.dot(embs1, embs2) / (torch.linalg.norm(embs1) * torch.linalg.norm(embs2))
    return ((cosine + 1) / 2).item()


def compare_samples(text, path1, path2):
    """Verify that two recordings say the OTP phrase and come from one speaker.

    Args:
        text: The challenge phrase shown to the user.
        path1: File path of the first recording (falsy when not recorded).
        path2: File path of the second recording (falsy when not recorded).

    Returns:
        An HTML string: an inline error when a recording is missing,
        ``OUTPUT_ERROR`` (filled with both transcripts) when the transcripts
        do not both equal the phrase, otherwise ``OUTPUT_OK`` or
        ``OUTPUT_FAIL`` depending on whether the speaker similarity reaches
        ``THRESHOLD``.
    """
    if not (path1 and path2):
        return '<b style="color:red">ERROR: Please record audio for *both* speakers!</b>'

    spoken1 = clean_sentence(audio_to_text(path1))
    spoken2 = clean_sentence(audio_to_text(path2))
    expected = clean_sentence(text)

    print("OTP Given:", expected)
    print("Spoken 1:", spoken1)
    print("Spoken 2:", spoken2)

    if not (spoken1 == spoken2 == expected):
        # Transcripts are derived from user-supplied audio; escape them
        # before splicing them into the HTML response.
        return OUTPUT_ERROR.format(html.escape(spoken1), html.escape(spoken2))

    if _speaker_similarity(path1, path2) >= THRESHOLD:
        return OUTPUT_OK
    return OUTPUT_FAIL


#
# def compare_samples1(path1, path2):
#     if not (path1 and path2):
#         return '<b style="color:red">ERROR: Please record audio for *both* speakers!</b>'
#
#     embs1 = model.get_embedding(path1).squeeze()
#     embs2 = model.get_embedding(path2).squeeze()
#
#     # Length Normalize
#     X = embs1 / torch.linalg.norm(embs1)
#     Y = embs2 / torch.linalg.norm(embs2)
#
#     # Score
#     similarity_score = torch.dot(X, Y) / ((torch.dot(X, X) * torch.dot(Y, Y)) ** 0.5)
#     similarity_score = (similarity_score + 1) / 2
#
#     # Decision
#     if similarity_score >= THRESHOLD:
#         return OUTPUT_OK.format(similarity_score * 100)
#     else:
#         return OUTPUT_FAIL.format(similarity_score * 100)


# Interface inputs, passed positionally to compare_samples(text, path1, path2):
# the challenge-phrase text box followed by one microphone recorder per sample.
inputs = text_inputs + [
    gr.inputs.Audio(
        source="microphone",
        type="filepath",
        optional=True,
        label=speaker_label,
    )
    for speaker_label in ("Speaker #1", "Speaker #2")
]

# upload_inputs = [
#     gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Speaker #1"),
#     gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Speaker #2"),
# ]

# Subtitle shown under the interface title.
description = "Compare two speech samples and determine if they are from the same speaker."

# Microphone-based verification UI: compare_samples receives the inputs
# above and returns one of the HTML result strings.
microphone_interface = gr.Interface(
    fn=compare_samples,
    inputs=inputs,
    outputs=gr.outputs.HTML(label=""),
    title="Speaker Verification",
    description=description,
    layout="horizontal",
    theme="huggingface",
    # Gradio 3.x (required for TabbedInterface) takes "never"/"auto"/"manual"
    # here; a bool is only mapped over with a deprecation warning.
    allow_flagging="never",
    live=False,
)

# upload_interface = gr.Interface(
#     fn=compare_samples1,
#     inputs=upload_inputs,
#     outputs=gr.outputs.HTML(label=""),
#     title="Speaker Verification",
#     description=description,
#     layout="horizontal",
#     theme="huggingface",
#     allow_flagging=False,
#     live=False,
# )

# Wrap the interface in a single "Microphone" tab and start the app.
demo = gr.TabbedInterface(
    [microphone_interface],
    ["Microphone"],
)
# demo = gr.TabbedInterface([microphone_interface, upload_interface], ["Microphone", "Upload File"])
demo.launch(share=True, enable_queue=True)