import html
import os
import subprocess
from collections import Counter

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import librosa
import av
from transformers import VivitImageProcessor, VivitForVideoClassification
from transformers import AutoConfig, Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
from moviepy.editor import VideoFileClip

def get_emotion_from_filename(filename):
    """Look up the emotion label encoded in a dash-separated filename.

    The third ``-``-separated field is parsed as an integer emotion code
    (presumably following the RAVDESS naming scheme). Codes that are not in
    the six-class table below yield ``None``.
    """
    code = int(filename.split('-')[2])
    known_codes = {
        1: 'neutral',
        3: 'happy',
        4: 'sad',
        5: 'angry',
        6: 'fearful',
        7: 'disgust',
    }
    return known_codes.get(code)

def separate_video_audio(file_path):
    """Split a media file into a silent video file and a WAV audio file.

    Both outputs are written to ``./temp/``:
    ``<stem>_video.mp4`` is re-encoded with libx264 and has no audio track;
    ``<stem>_audio.wav`` is 16-bit PCM resampled to 16 kHz.

    Args:
        file_path: Path of the input media file (any extension ffmpeg reads).

    Returns:
        tuple[str, str]: ``(video_path, audio_path)``.

    Raises:
        subprocess.CalledProcessError: If either ffmpeg invocation fails.
    """
    output_dir = './temp/'
    os.makedirs(output_dir, exist_ok=True)

    # splitext handles any extension. The old str.replace('.mp4', ...) left
    # non-.mp4 names unchanged, so both outputs got the same path.
    stem = os.path.splitext(os.path.basename(file_path))[0]
    video_path = os.path.join(output_dir, stem + '_video.mp4')
    audio_path = os.path.join(output_dir, stem + '_audio.wav')

    # -y overwrites leftovers from an earlier (failed) run instead of
    # prompting; -nostdin guarantees ffmpeg never blocks waiting on a prompt.
    video_cmd = ['ffmpeg', '-nostdin', '-y', '-loglevel', 'quiet', '-i', file_path,
                 '-an', '-c:v', 'libx264', '-preset', 'ultrafast', video_path]
    subprocess.run(video_cmd, check=True)

    audio_cmd = ['ffmpeg', '-nostdin', '-y', '-loglevel', 'quiet', '-i', file_path,
                 '-vn', '-acodec', 'pcm_s16le', '-ar', '16000', audio_path]
    subprocess.run(audio_cmd, check=True)

    return video_path, audio_path

def delete_files_in_directory(directory):
    """Best-effort removal of every regular file directly inside *directory*.

    Subdirectories (and their contents) are left alone. If removing a single
    entry fails, the reason is printed and the loop carries on with the rest.
    """
    for entry in os.listdir(directory):
        target = os.path.join(directory, entry)
        try:
            if not os.path.isfile(target):
                continue
            os.remove(target)
        except Exception as err:
            print(f"Failed to delete {target}. Reason: {err}")

def get_total_frames(container):
    """Return the number of frames in the first video stream of *container*.

    PyAV reports ``stream.frames == 0`` when the container header carries no
    frame count. In that case this function counts the frames by decoding
    the stream once, then rewinds the container with ``seek(0)`` so callers
    can decode it again.

    Args:
        container: An opened PyAV input container.

    Returns:
        int: Number of frames in the first video stream.
    """
    stream = container.streams.video[0]
    total_frames = stream.frames
    if total_frames:
        return total_frames
    # Frame count unknown from metadata: fall back to one full decode pass.
    total_frames = sum(1 for _ in container.decode(video=0))
    container.seek(0)
    return total_frames

def process_video(file_path):
    """Decode a 32-frame clip from the video at *file_path*.

    Picks 32 indices spread over a randomly placed 64-frame window (via
    ``sample_frame_indices`` with ``clip_len=32, frame_sample_rate=2``) and
    decodes those frames with ``read_video_pyav``.

    Args:
        file_path: Path of the video file to open with PyAV.

    Returns:
        numpy.ndarray: The stacked frames returned by ``read_video_pyav``.

    Raises:
        ValueError: If the video has fewer than 64 frames.
    """
    # The context manager closes the container on every path. Previously it
    # leaked whenever sampling or decoding raised.
    with av.open(file_path) as container:
        total_frames = get_total_frames(container)
        if total_frames < 64:
            raise ValueError("Video must have at least 64 frames.")
        indices = sample_frame_indices(clip_len=32, frame_sample_rate=2, seg_len=total_frames)
        return read_video_pyav(container=container, indices=indices)

def read_video_pyav(container, indices):
    """Decode the frames at *indices* from the first video stream of *container*.

    The container is rewound with ``seek(0)`` before decoding. Each selected
    frame is resized to 224x224 and converted to RGB. Frames come back in
    decode order. An index listed more than once yields that frame the same
    number of times, so the result has ``len(indices)`` frames whenever every
    index is decodable.

    Args:
        container: An opened PyAV input container.
        indices: Iterable of integer frame indices to keep (any order).

    Returns:
        numpy.ndarray: Array of shape ``(n_frames, 224, 224, 3)``.

    Raises:
        ValueError: If *indices* is empty or none of the requested frames
            could be decoded.
    """
    # Count each requested index: duplicates are honoured and membership is O(1).
    wanted = Counter(int(i) for i in indices)
    if not wanted:
        raise ValueError("No frame indices were requested.")
    last_index = max(wanted)

    arrays = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > last_index:
            break
        repeats = wanted[i]
        if repeats:
            rgb = frame.reformat(width=224, height=224).to_ndarray(format="rgb24")
            arrays.extend([rgb] * repeats)

    if not arrays:
        # np.stack([]) would otherwise fail with an unhelpful message.
        raise ValueError("Could not decode any of the requested video frames.")
    return np.stack(arrays)

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Randomly choose ``clip_len`` evenly spaced frame indices.

    A window of ``clip_len * frame_sample_rate`` consecutive frames is placed
    at a random position inside the first ``seg_len`` frames, and
    ``clip_len`` evenly spaced indices are taken from that window.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Approximate spacing between returned indices.
        seg_len: Total number of frames available.

    Returns:
        numpy.ndarray: Non-decreasing int64 indices in ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is shorter than the sampling window.
    """
    converted_len = int(clip_len * frame_sample_rate)
    if seg_len < converted_len:
        raise ValueError(f"Need at least {converted_len} frames to sample, got {seg_len}.")
    if seg_len == converted_len:
        # np.random.randint(low, high) requires high > low, so a video exactly
        # one window long used to crash. The only valid window ends at seg_len.
        end_idx = converted_len
    else:
        end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices

def video_label_to_emotion(label):
    """Translate a video-model label such as ``"LABEL_2"`` into an emotion name.

    The integer after the first underscore is the class index. Indices with no
    entry in the table give ``"Unknown Label"``.
    """
    emotions = dict(enumerate(('neutral', 'happy', 'sad', 'angry', 'fearful', 'disgust')))
    class_index = int(label.split('_')[1])
    return emotions.get(class_index, "Unknown Label")

def predict_video(file_path, video_model, image_processor):
    """Return per-emotion probabilities predicted by the video model.

    Args:
        file_path: Path of the video file to classify.
        video_model: Video classification model. It is switched to eval mode
            here.
        image_processor: Processor that turns the list of frames into model
            inputs.

    Returns:
        dict[str, float]: Emotion name -> softmax probability.

    Raises:
        ValueError: Propagated from ``process_video`` (e.g. video too short).
    """
    video = process_video(file_path)
    inputs = image_processor(list(video), return_tensors="pt")
    device = torch.device("cpu")
    inputs = inputs.to(device)

    # The model comes from torch.load, which restores whatever train/eval
    # mode it was pickled in. Force eval mode so dropout is disabled during
    # inference.
    video_model.eval()
    with torch.no_grad():
        logits = video_model(**inputs).logits
    # Take the single batch item explicitly. squeeze() would collapse a
    # one-label output to a 0-d tensor. This also matches the audio path.
    probs = F.softmax(logits, dim=-1)[0]

    id2label = video_model.config.id2label
    return {
        video_label_to_emotion(id2label[idx]): float(prob)
        for idx, prob in enumerate(probs)
    }

def audio_label_to_emotion(label):
    """Translate an audio-model label such as ``"LABEL_3"`` into an emotion name.

    The integer after the first underscore is the class index (classes are in
    alphabetical order). Indices with no entry give ``"Unknown Label"``.
    """
    emotions = dict(enumerate(('angry', 'disgust', 'fearful', 'happy', 'neutral', 'sad')))
    class_index = int(label.split('_')[1])
    return emotions.get(class_index, "Unknown Label")

def preprocess_and_predict_audio(file_path, model, processor):
    """Return per-emotion probabilities predicted by the audio model.

    The audio file is loaded and resampled to 16 kHz with librosa, converted
    to model inputs by *processor*, and run through *model* on the CPU
    without gradient tracking.

    Args:
        file_path: Path of the audio file to classify.
        model: Sequence-classification model (moved to CPU here).
        processor: Feature extractor that builds the model inputs.

    Returns:
        dict[str, float]: Emotion name -> softmax probability for the first
        item of the batch.
    """
    waveform, _sample_rate = librosa.load(file_path, sr=16000)
    # NOTE(review): with padding=True (pad to longest) and no truncation=True,
    # max_length likely has no effect here. Confirm the intended behaviour.
    features = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True, max_length=75275)

    cpu = torch.device("cpu")
    model = model.to(cpu)
    batch = {name: tensor.to(cpu) for name, tensor in features.items()}

    with torch.no_grad():
        logits = model(**batch).logits
    scores = F.softmax(logits, dim=-1)[0]

    labels = model.config.id2label
    return {audio_label_to_emotion(labels[i]): float(p) for i, p in enumerate(scores)}

def averaging_method(video_prediction, audio_prediction):
    """Return the label whose unweighted mean probability is highest.

    Every label from either dictionary is considered. A label missing from one
    modality counts as probability 0 for it.
    """
    all_labels = set(video_prediction) | set(audio_prediction)
    mean_scores = {
        label: (video_prediction.get(label, 0) + audio_prediction.get(label, 0)) / 2
        for label in all_labels
    }
    return max(mean_scores, key=mean_scores.get)

def weighted_average_method(video_prediction, audio_prediction, video_weight, audio_weight=0.6):
    """Return the label with the highest weighted mean of the two modalities.

    A label missing from one dictionary counts as probability 0 for that
    modality.

    Args:
        video_prediction: Mapping of emotion label -> video-model probability.
        audio_prediction: Mapping of emotion label -> audio-model probability.
        video_weight: Weight applied to the video probabilities.
        audio_weight: Weight applied to the audio probabilities. Defaults to
            0.6, the value that was previously hard-coded.

    Returns:
        The label with the largest weighted average.

    Raises:
        ValueError: If the two weights sum to zero.
    """
    total_weight = video_weight + audio_weight
    if total_weight == 0:
        raise ValueError("video_weight and audio_weight must not sum to zero.")
    combined_probabilities = {}
    for label in set(video_prediction) | set(audio_prediction):
        video_prob = video_prediction.get(label, 0)
        audio_prob = audio_prediction.get(label, 0)
        combined_probabilities[label] = (video_weight * video_prob + audio_weight * audio_prob) / total_weight
    return max(combined_probabilities, key=combined_probabilities.get)

def confidence_level_method(video_prediction, audio_prediction, threshold=0.7):
    """Trust the video model when it is confident enough; otherwise average.

    If the video model's top probability is at least *threshold*, its top
    label is returned as-is. Otherwise the label with the highest unweighted
    mean across both modalities wins. A missing label counts as 0.
    """
    top_video_label = max(video_prediction, key=video_prediction.get)
    if not video_prediction[top_video_label] < threshold:
        return top_video_label

    mean_scores = {
        label: (video_prediction.get(label, 0) + audio_prediction.get(label, 0)) / 2
        for label in set(video_prediction) | set(audio_prediction)
    }
    return max(mean_scores, key=mean_scores.get)

def dynamic_weighting_method(video_prediction, audio_prediction):
    """Fuse predictions, weighting the modalities per label by their confidence.

    For each label, each modality's probability is normalised by that
    modality's total probability mass. The two normalised values then act as
    the relative weights of the two raw probabilities. A label to which both
    modalities assign zero probability scores 0 instead of raising
    ``ZeroDivisionError``. A label missing from one dictionary counts as 0.

    Returns:
        The label with the highest combined score.
    """
    # Hoisted: the totals do not change between labels.
    video_total = sum(video_prediction.values())
    audio_total = sum(audio_prediction.values())
    combined_probabilities = {}
    for label in set(video_prediction) | set(audio_prediction):
        video_prob = video_prediction.get(label, 0)
        audio_prob = audio_prediction.get(label, 0)
        video_confidence = video_prob / video_total
        audio_confidence = audio_prob / audio_total
        confidence_sum = video_confidence + audio_confidence
        if confidence_sum == 0:
            # Both modalities give this label zero probability.
            combined_probabilities[label] = 0.0
            continue
        video_weight = video_confidence / confidence_sum
        audio_weight = audio_confidence / confidence_sum
        combined_probabilities[label] = video_weight * video_prob + audio_weight * audio_prob
    return max(combined_probabilities, key=combined_probabilities.get)

def rule_based_method(video_prediction, audio_prediction, threshold=0.5):
    """Prefer whichever modality is more confident in its own top label.

    A modality's confidence is its top probability divided by its total
    probability mass. If the two confidences are exactly equal, the label
    with the highest unweighted mean across both modalities is returned.

    NOTE(review): the first rule (top labels agree and both confidences
    exceed *threshold*) returns the same label the confidence comparison
    would return anyway, so *threshold* effectively never changes the result
    (apart from exact ties). Confirm the intended rule.
    """
    video_top = max(video_prediction, key=video_prediction.get)
    audio_top = max(audio_prediction, key=audio_prediction.get)
    video_conf = video_prediction[video_top] / sum(video_prediction.values())
    audio_conf = audio_prediction[audio_top] / sum(audio_prediction.values())
    mean_scores = {
        label: (video_prediction.get(label, 0) + audio_prediction.get(label, 0)) / 2
        for label in set(video_prediction) | set(audio_prediction)
    }

    labels_agree = video_top == audio_top
    if labels_agree and video_conf > threshold and audio_conf > threshold:
        return video_top
    if video_conf > audio_conf:
        return video_top
    if audio_conf > video_conf:
        return audio_top
    return max(mean_scores, key=mean_scores.get)

# Maps each decision-framework name shown in the UI dropdown (insertion order
# is the dropdown order) to the function that fuses the video and audio
# probability dictionaries into a single label.
decision_frameworks = {
    "Averaging": averaging_method,
    "Weighted Average": weighted_average_method,
    "Confidence Level": confidence_level_method,
    "Dynamic Weighting": dynamic_weighting_method,
    "Rule-Based": rule_based_method,
}

# Checkpoint file for each selectable video model (whole pickled modules).
_VIDEO_MODEL_CHECKPOINTS = {
    "60% Accuracy": "video_model_60_acc.pth",
    "80% Accuracy": "video_model_80_acc.pth",
}
# Fine-tuned classifier state dict for each selectable audio model.
_AUDIO_MODEL_STATE_DICTS = {
    "60% Accuracy": "audio_model_state_dict_6e.pth",
}
# video_weight passed to weighted_average_method for each video model.
_VIDEO_FUSION_WEIGHTS = {
    "60% Accuracy": 0.6,
    "80% Accuracy": 0.88,
}
_VIDEO_PROCESSOR_ID = "google/vivit-b-16x2-kinetics400"
_AUDIO_BASE_MODEL_ID = "facebook/wav2vec2-large"
# Models/processors loaded so far. They are reused across requests instead of
# being reloaded from disk on every call to predict().
_MODEL_CACHE = {}


def _load_image_processor():
    """Return the ViViT image processor, loading it on first use."""
    if "image_processor" not in _MODEL_CACHE:
        _MODEL_CACHE["image_processor"] = VivitImageProcessor.from_pretrained(_VIDEO_PROCESSOR_ID)
    return _MODEL_CACHE["image_processor"]


def _load_video_model(video_model_name):
    """Return the video model for a dropdown label, loading it on first use."""
    key = ("video", video_model_name)
    if key not in _MODEL_CACHE:
        # The .pth files hold whole pickled nn.Modules, not state dicts, so
        # weights_only=False is required on torch >= 2.6 (where the default
        # became True). This unpickles arbitrary objects, so only ship
        # trusted checkpoint files.
        _MODEL_CACHE[key] = torch.load(
            _VIDEO_MODEL_CHECKPOINTS[video_model_name],
            map_location=torch.device('cpu'),
            weights_only=False,
        )
    return _MODEL_CACHE[key]


def _load_audio_model(audio_model_name):
    """Return ``(model, feature_extractor)`` for a dropdown label, loading on first use."""
    key = ("audio", audio_model_name)
    if key not in _MODEL_CACHE:
        config = AutoConfig.from_pretrained(_AUDIO_BASE_MODEL_ID, num_labels=6)
        audio_processor = AutoFeatureExtractor.from_pretrained(_AUDIO_BASE_MODEL_ID)
        audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(_AUDIO_BASE_MODEL_ID, config=config)
        audio_model.load_state_dict(
            torch.load(_AUDIO_MODEL_STATE_DICTS[audio_model_name], map_location=torch.device('cpu'))
        )
        audio_model.eval()
        _MODEL_CACHE[key] = (audio_model, audio_processor)
    return _MODEL_CACHE[key]


def _error_html(message):
    """Format *message* (HTML-escaped) as the error panel shown in the UI."""
    return f"""
        <h2>Error</h2>
        <p>{html.escape(message)}</p>
        """


def predict(video_file, video_model_name, audio_model_name, framework_name):
    """Run both models on an uploaded video and fuse their predictions.

    Args:
        video_file: Path of the uploaded video, as passed by ``gr.Video``.
        video_model_name: Video model dropdown label, e.g. ``"60% Accuracy"``.
        audio_model_name: Audio model dropdown label.
        framework_name: Key of ``decision_frameworks`` selecting the fusion rule.

    Returns:
        str: HTML listing the video, audio and consensus labels, or an error
        panel for missing/invalid selections, too-short videos and ffmpeg
        failures.
    """
    # Gradio passes None for an empty upload or an untouched dropdown. Reject
    # those up front instead of hitting UnboundLocalError / KeyError later,
    # or silently running an audio model with an untrained classifier head.
    if not video_file:
        return _error_html("Please upload a video.")
    if video_model_name not in _VIDEO_MODEL_CHECKPOINTS:
        return _error_html("Please select a valid video model.")
    if audio_model_name not in _AUDIO_MODEL_STATE_DICTS:
        return _error_html("Please select a valid audio model.")
    if framework_name not in decision_frameworks:
        return _error_html("Please select a valid decision framework.")

    image_processor = _load_image_processor()
    video_model = _load_video_model(video_model_name)
    audio_model, audio_processor = _load_audio_model(audio_model_name)

    delete_directory_path = "./temp/"

    try:
        video_path, audio_path = separate_video_audio(video_file)

        video_prediction = predict_video(video_path, video_model, image_processor)
        highest_video_emotion = max(video_prediction, key=video_prediction.get)

        audio_prediction = preprocess_and_predict_audio(audio_path, audio_model, audio_processor)
        highest_audio_emotion = max(audio_prediction, key=audio_prediction.get)

        framework_function = decision_frameworks[framework_name]
        if framework_function is weighted_average_method:
            consensus_label = framework_function(
                video_prediction, audio_prediction, _VIDEO_FUSION_WEIGHTS[video_model_name]
            )
        else:
            consensus_label = framework_function(video_prediction, audio_prediction)
    except ValueError as e:
        return _error_html(str(e))
    except subprocess.CalledProcessError:
        return _error_html("Could not process the uploaded file with ffmpeg.")
    finally:
        # Clean up on every path, not only on success, so stale temp files
        # never accumulate or collide with the next request's outputs.
        if os.path.isdir(delete_directory_path):
            delete_files_in_directory(delete_directory_path)

    return f"""
        <h2>Predictions</h2>
        <p><strong>Video Label:</strong> {highest_video_emotion}</p>
        <p><strong>Audio Label:</strong> {highest_audio_emotion}</p>
        <p><strong>Consensus Label:</strong> {consensus_label}</p>
        """

# Input components, in the same order as predict()'s parameters. Each dropdown
# gets a default selection so that submitting without touching it does not
# send None to predict().
framework_choices = list(decision_frameworks.keys())
inputs = [
    gr.Video(label="Upload Video"),
    gr.Dropdown(["60% Accuracy", "80% Accuracy"], value="60% Accuracy", label="Select Video Model"),
    gr.Dropdown(["60% Accuracy"], value="60% Accuracy", label="Select Audio Model"),
    gr.Dropdown(framework_choices, value=framework_choices[0], label="Select Decision Framework"),
]

# predict() returns an HTML snippet (predictions or an error panel).
outputs = [
    gr.HTML(label="Predictions")
]

iface = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["./Angry.mp4", "60% Accuracy", "60% Accuracy", "Averaging"],
        ["./Disgust.mp4", "80% Accuracy", "60% Accuracy", "Weighted Average"],
        ["./Fearful.mp4", "60% Accuracy", "60% Accuracy", "Confidence Level"],
        ["./Happy.mp4", "80% Accuracy", "60% Accuracy", "Dynamic Weighting"],
        ["./Neutral.mp4", "80% Accuracy", "60% Accuracy", "Rule-Based"],
        ["./Sad.mp4", "60% Accuracy", "60% Accuracy", "Weighted Average"]
        ],
    title="Video and Audio Emotion Prediction",
    description="Upload a video to get emotion predictions from selected video and audio models. Example videos are from the RAVDESS dataset."
)

iface.launch(debug=True, share=True)