from PhantomNET import PhantomNet
import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
import numpy as np
from sklearn.linear_model import LogisticRegression
import gradio as gr
import yt_dlp as youtube_dl
import os

class HuggingFaceFeatureExtractor:
    """Callable that pairs a Hugging Face preprocessor with its model.

    The preprocessor and the model are both loaded from the same
    checkpoint ``name``. The model is switched to eval mode and moved to
    CUDA when available, otherwise it stays on the CPU.
    """

    def __init__(self, model_class, name):
        """Load the preprocessor and ``model_class`` weights for ``name``."""
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(name)
        # Hidden states of every layer are requested because __call__
        # returns two intermediate layers in addition to the last one.
        self.model = model_class.from_pretrained(name, output_hidden_states=True)
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        """Run the model on ``audio`` sampled at ``sr`` without gradients.

        Returns:
            A 3-tuple ``(hidden_states[9], hidden_states[8],
            last_hidden_state)`` taken from the model output.
        """
        encoded = self.feature_extractor(
            audio, sampling_rate=sr, return_tensors="pt", padding=True
        )
        # Every preprocessed tensor has to live on the model's device.
        on_device = {}
        for key, tensor in encoded.items():
            on_device[key] = tensor.to(self.device)

        with torch.no_grad():
            result = self.model(**on_device)

        states = result.hidden_states
        return states[9], states[8], result.last_hidden_state


def _load_xls_r_2b():
    """Build the extractor backed by the ``facebook/wav2vec2-xls-r-2b`` checkpoint."""
    return HuggingFaceFeatureExtractor(Wav2Vec2Model, "facebook/wav2vec2-xls-r-2b")


# Registry of zero-argument factories keyed by model name; the extractor is
# only instantiated when the factory is called, not when this module loads.
FEATURE_EXTRACTOR = {"wav2vec2-xls-r-2b": _load_xls_r_2b}


# Serialized ensemble stages, loaded once at import time from the working
# directory. joblib.load unpickles its input, so these files must come from
# a trusted source.
_BASE_MODEL_FILES = (
    'model1_ensemble.pkl',
    'model2_ensemble.pkl',
    'model3_ensemble.pkl',
    'model4_ensemble.pkl',
)
# model1..model4 each score one embedding; final_model combines their outputs
# (see process_audio).
model1, model2, model3, model4 = map(joblib.load, _BASE_MODEL_FILES)
final_model = joblib.load('final_model_ensemble.pkl')

# def download_audio_from_youtube(youtube_url, output_path='.'):
#     ydl_opts = {
#         'format': 'bestaudio/best',
#         'outtmpl': f'{output_path}/%(title)s.%(ext)s',
#         'postprocessors': [{
#             'key': 'FFmpegExtractAudio',
#             'preferredcodec': 'wav',
#             'preferredquality': '192',
#         }],
#         'postprocessor_args': ['-ar', '16000'],
#         'prefer_ffmpeg': True,
#     }

#     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
#         info_dict = ydl.extract_info(youtube_url, download=True)
#         # I had issues with the .webm extension, so force-replace it with .wav
#         audio_file = ydl.prepare_filename(info_dict).replace('.webm', '.wav')
#     return audio_file

def download_audio_from_youtube(youtube_url, output_path='.', cookies_file='cookies.txt'):
    """Download the audio track of a video with yt-dlp and return the WAV path.

    The best available audio stream is downloaded and converted by the
    FFmpeg post-processor to WAV, with ``-ar 16000`` passed to FFmpeg.

    Args:
        youtube_url: URL handed to yt-dlp's ``extract_info``.
        output_path: Directory used in the output filename template.
        cookies_file: Path of the cookie file given to yt-dlp. When the
            ``cookies`` environment variable is set, its content is written
            to this path for the duration of the download and the file is
            removed afterwards. When the variable is unset, an already
            existing file at this path is used as-is and left in place.
            Pass a falsy value to download without cookies.

    Returns:
        Path of the extracted ``.wav`` file.
    """
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': f'{output_path}/%(title)s.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192',
        }],
        'postprocessor_args': ['-ar', '16000'],
        'prefer_ffmpeg': True,
    }

    cookies_content = os.getenv('cookies')
    created_cookie_file = False
    try:
        if cookies_file:
            # Only materialise the cookie file when the environment actually
            # provides content; writing None would raise a TypeError.
            if cookies_content:
                created_cookie_file = True
                with open(cookies_file, 'w') as file:
                    file.write(cookies_content)
            if created_cookie_file or os.path.exists(cookies_file):
                ydl_opts['cookiefile'] = cookies_file

        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(youtube_url, download=True)
            downloaded_name = ydl.prepare_filename(info_dict)
    finally:
        # Never leave credentials written by this call on disk, even when
        # the download fails part-way through.
        if created_cookie_file and os.path.exists(cookies_file):
            os.remove(cookies_file)

    # prepare_filename reports the name of the downloaded container (.webm,
    # .m4a, .opus, ...); the post-processor swaps that extension for .wav.
    return os.path.splitext(downloaded_name)[0] + '.wav'

def segment_audio(audio, sr, segment_duration):
    """Split ``audio`` into consecutive chunks of ``segment_duration`` seconds.

    Args:
        audio: Sliceable sequence of samples.
        sr: Sampling rate in samples per second.
        segment_duration: Chunk length in seconds.

    Returns:
        List of slices in order; the final one holds whatever samples remain
        and may therefore be shorter than the others.
    """
    hop = int(segment_duration * sr)
    chunks = []
    for start in range(0, len(audio), hop):
        chunks.append(audio[start:start + hop])
    return chunks

def classify_with_eer_threshold(probabilities, eer_thresh):
    """Turn probabilities into 0/1 labels using ``eer_thresh`` as the cut-off.

    Values greater than or equal to the threshold map to 1, all others to 0.
    """
    at_or_above = probabilities >= eer_thresh
    return at_or_above.astype(int)

def process_audio(input_data, segment_duration=3):
    """Classify an audio clip as real or fake and return a verdict string.

    The clip is loaded at 16 kHz, cut into ``segment_duration``-second
    segments, embedded with wav2vec2 (three layers) and PhantomNet, and the
    four time-averaged embeddings are scored by ``model1``..``model4``. Their
    positive-class probabilities are stacked and scored by ``final_model``.

    Args:
        input_data: Local audio path, or a URL starting with ``http`` that
            is first downloaded with ``download_audio_from_youtube``.
        segment_duration: Segment length in seconds.

    Returns:
        ``"Fake with a confidence of: ..."`` or
        ``"Real with a confidence of: ..."``.

    Raises:
        ValueError: If the clip contains no segment long enough to embed.
    """
    if input_data.startswith("http"):
        file_audio = download_audio_from_youtube(input_data)
    else:
        file_audio = input_data

    audio, sr = librosa.load(file_audio, sr=16000)
    if len(audio.shape) > 1:
        audio = audio[0]

    # A trailing remainder of only a few samples cannot pass through the
    # wav2vec2 convolutional front end. 400 samples is the receptive field of
    # the standard wav2vec2 conv stack (kernels 10,3,3,3,3,2,2 / strides
    # 5,2,2,2,2,2,2) -- NOTE(review): assumes this checkpoint uses that layout.
    min_segment_samples = 400
    segments = [
        segment
        for segment in segment_audio(audio, sr, segment_duration)
        if len(segment) >= min_segment_samples
    ]
    if not segments:
        raise ValueError(
            "Audio is too short to analyse: no segment of at least "
            f"{min_segment_samples} samples was found."
        )

    # wav2vec2 extractor
    # NOTE(review): the extractor (and PhantomNet below) is rebuilt on every
    # call; consider caching if request latency matters.
    extractor = FEATURE_EXTRACTOR['wav2vec2-xls-r-2b']()
    all_embeddings_layer10 = []
    all_embeddings_layer9 = []
    all_embeddings_layer48 = []
    for segment in segments:
        p1, p2, p3 = extractor(segment, sr)
        all_embeddings_layer10.append(p1)
        all_embeddings_layer9.append(p2)
        all_embeddings_layer48.append(p3)

    # Concatenate the per-segment outputs along dim 1, then average over it
    # to obtain one vector per layer for the whole clip.
    wav2vec2_feature_layer10 = (
        torch.mean(torch.cat(all_embeddings_layer10, dim=1), dim=1).cpu().numpy().reshape(1, -1)
    )
    wav2vec2_feature_layer9 = (
        torch.mean(torch.cat(all_embeddings_layer9, dim=1), dim=1).cpu().numpy().reshape(1, -1)
    )
    wav2vec2_feature_layer48 = (
        torch.mean(torch.cat(all_embeddings_layer48, dim=1), dim=1).cpu().numpy().reshape(1, -1)
    )

    # PhantomNet extractor
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = PhantomNet(feature_size=1920, num_classes=2, conv_projection=False, use_mode='extractor').to(device)
    state_dict = torch.load("PhantomNet_Finetuned_V2.pt", map_location=device)
    # strict=False tolerates key mismatches between checkpoint and module.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    all_embeddings_PhantomNet = []
    # Inference only: skip building an autograd graph for every segment.
    with torch.no_grad():
        for segment in segments:
            segment_input = torch.Tensor(segment).unsqueeze(0).to(device)
            all_embeddings_PhantomNet.append(model(segment_input))

    embedding_PhantomNet = torch.cat(all_embeddings_PhantomNet, dim=1)
    # scikit-learn needs a host-side array; a CUDA tensor cannot be converted
    # implicitly, so move it to the CPU like the wav2vec2 features above.
    PhantomNet_feature = (
        torch.mean(embedding_PhantomNet, dim=1).reshape(1, -1).detach().cpu().numpy()
    )

    eval_prob1 = model1.predict_proba(wav2vec2_feature_layer10)[:, 1].reshape(-1, 1)
    eval_prob2 = model2.predict_proba(wav2vec2_feature_layer9)[:, 1].reshape(-1, 1)
    eval_prob3 = model3.predict_proba(wav2vec2_feature_layer48)[:, 1].reshape(-1, 1)
    eval_prob4 = model4.predict_proba(PhantomNet_feature)[:, 1].reshape(-1, 1)

    eval_combined_probs = np.hstack((eval_prob1, eval_prob2, eval_prob3, eval_prob4))

    eer_thresh = 0.02  # eer during evaluation

    final_prob = final_model.predict_proba(eval_combined_probs)[:, 1]

    y_pred_inference = classify_with_eer_threshold(final_prob, eer_thresh)

    # Exactly one clip is scored, so read its label explicitly instead of
    # relying on the truthiness of a one-element array.
    if y_pred_inference[0] == 1:
        return f"Fake with a confidence of: {final_prob[0] * 100:.2f}%"
    else:
        return f"Real with a confidence of: {100 - final_prob[0] * 100:.2f}%"

def gradio_interface(audio, youtube_link):
    """Gradio callback: analyse the YouTube link if given, else the upload.

    Args:
        audio: Value of the audio input component.
        youtube_link: Text entered in the link box.

    Returns:
        The verdict from ``process_audio``, or a prompt asking for input
        when both arguments are empty.
    """
    # The link takes precedence over the uploaded file when both are set.
    source = youtube_link or audio
    if not source:
        return "please upload audio or provide a YouTube link."
    return process_audio(source)

# Input widgets, in the positional order expected by gradio_interface.
_audio_input = gr.Audio(type="filepath", label="Upload Audio")
_link_input = gr.Textbox(label="YouTube Link (Optional)")

# Web UI wiring: both inputs feed gradio_interface, which returns plain text.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[_audio_input, _link_input],
    outputs="text",
    title="AI4TRUST Development",
    description="Upload an audio file or provide a YouTube link to check if it's AI generated",
)

# Start the app as soon as this module runs, with share=True.
interface.launch(share=True)