import gradio as gr
from collections import defaultdict
import os
import base64
from datasets import (
    Dataset,
    load_dataset,
)
import pandas as pd
import itertools

TOKEN = os.environ['TOKEN']  # Hugging Face token: read access to the eval dataset, write access to the results repo


MASKED_LM_MODELS = [
    "BounharAbdelaziz/XLM-RoBERTa-Morocco",
    "SI2M-Lab/DarijaBERT",
    "BounharAbdelaziz/ModernBERT-Morocco",
    "google-bert/bert-base-multilingual-cased",
    "FacebookAI/xlm-roberta-large",
    "aubmindlab/bert-base-arabertv02",
]

CAUSAL_LM_MODELS = [
    "BounharAbdelaziz/Al-Atlas-LLM-0.5B",
    "Qwen/Qwen2.5-0.5B",
    "tiiuae/Falcon3-1B-Base",
    "MBZUAI-Paris/Atlas-Chat-2B",
]
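
# Note: the evaluation dataset is expected to contain one column per model id above,
# holding that model's precomputed output for each prompt row
# (see LMBattleArena.get_next_battle_pair).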

def encode_image_to_base64(image_path):
    """Encode an image or GIF file to base64."""
    with open(image_path, "rb") as file:
        encoded_string = base64.b64encode(file.read()).decode()
    return encoded_string

def create_html_media(media_path, is_gif=False):
    """Create HTML for displaying an image or GIF."""
    media_base64 = encode_image_to_base64(media_path)
    media_type = "gif" if is_gif else "jpeg"
    
    html_string = f"""
    <div style="display: flex; justify-content: center; align-items: center; width: 100%; text-align: center;">
        <div style="max-width: 450px; margin: auto;">
            <img src="data:image/{media_type};base64,{media_base64}"
                 style="max-width: 75%; height: auto; display: block; margin: 0 auto; margin-top: 50px;"
                 alt="Displayed Media">
        </div>
    </div>
    """
    return html_string
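
# Example: create_html_media("battle_leaderboard.gif", is_gif=True) returns an <img> tag
# with the file inlined as a base64 data URI; this is how the header banner is rendered
# in the Blocks UI below.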

class LMBattleArena:
    def __init__(self, dataset_path, saving_freq=25):
        """Initialize battle arena with dataset"""
        self.df = pd.read_csv(dataset_path)
        self.current_index = 0
        self.saving_freq = saving_freq  # save the results in csv/push to hub every saving_freq evaluations
        self.evaluation_results_masked = []
        self.evaluation_results_causal = []
        self.model_scores = defaultdict(lambda: {'wins': 0, 'total_comparisons': 0})
        
        # Generate all possible model pairs
        self.masked_model_pairs = list(itertools.combinations(MASKED_LM_MODELS, 2))
        self.causal_model_pairs = list(itertools.combinations(CAUSAL_LM_MODELS, 2))
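        # combinations() yields every unordered pair exactly once, so the 6 masked-LM
        # models give C(6, 2) = 15 pairs and the 4 causal-LM models give C(4, 2) = 6 pairs.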
                
        # Pair indices to track which pair is being evaluated
        self.masked_pair_idx = 0
        self.causal_pair_idx = 0
        
        # To track which rows have been evaluated for which model pairs
        self.row_model_pairs_evaluated = set()  # Using a simple set
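        # Keys look like f"{row_index}_{'causal' or 'masked'}_{pair_index}", e.g. "3_masked_7"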
    
    def get_next_battle_pair(self, is_causal):
        """Retrieve the next pair of model outputs to compare, making sure every
        model pair is evaluated on every row exactly once."""

        if self.current_index >= len(self.df):
            # End of the dataset. If every (row, pair) combination for this arena has
            # already been evaluated, we are done; otherwise restart from the first row
            # and let the evaluated-set checks below skip the finished combinations.
            tag = 'causal' if is_causal else 'masked'
            n_pairs = len(self.causal_model_pairs) if is_causal else len(self.masked_model_pairs)
            n_done = sum(1 for key in self.row_model_pairs_evaluated if f"_{tag}_" in key)
            if n_done >= len(self.df) * n_pairs:
                return None
            self.current_index = 0
        
        row = self.df.iloc[self.current_index]
        
        # Get the current model pair to evaluate
        if is_causal:
            # Check if we've evaluated all causal model pairs
            if self.causal_pair_idx >= len(self.causal_model_pairs):
                # Move to next row and reset pair index
                self.current_index += 1
                self.causal_pair_idx = 0
                # Try again with the next row
                return self.get_next_battle_pair(is_causal)
            
            model_pair = self.causal_model_pairs[self.causal_pair_idx]
            pair_key = f"{self.current_index}_causal_{self.causal_pair_idx}"
            
            # Check if this row-pair combination has been evaluated
            if pair_key in self.row_model_pairs_evaluated:
                # Move to next pair
                self.causal_pair_idx += 1
                return self.get_next_battle_pair(is_causal)
            
            # Mark this row-pair combination as evaluated
            self.row_model_pairs_evaluated.add(pair_key)
            # Move to next pair for next evaluation
            self.causal_pair_idx += 1
            
            # Check if we've gone through all pairs for this row
            if self.causal_pair_idx >= len(self.causal_model_pairs):
                # Reset pair index and move to next row for next evaluation
                self.causal_pair_idx = 0
                self.current_index += 1
        else:
            # Similar logic for masked models
            if self.masked_pair_idx >= len(self.masked_model_pairs):
                self.current_index += 1
                self.masked_pair_idx = 0
                return self.get_next_battle_pair(is_causal)
            
            model_pair = self.masked_model_pairs[self.masked_pair_idx]
            pair_key = f"{self.current_index}_masked_{self.masked_pair_idx}"
            
            if pair_key in self.row_model_pairs_evaluated:
                self.masked_pair_idx += 1
                return self.get_next_battle_pair(is_causal)
            
            self.row_model_pairs_evaluated.add(pair_key)
            self.masked_pair_idx += 1
            
            if self.masked_pair_idx >= len(self.masked_model_pairs):
                self.masked_pair_idx = 0
                self.current_index += 1
        
        # Prepare the battle data with the selected model pair
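        # (each model's output is read from the dataset column named after that model id;
        #  prompts come from the 'masked_sentence'/'causal_sentence' columns)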
        battle_data = {
            'prompt': row['masked_sentence'] if not is_causal else row['causal_sentence'],
            'model_1': row[model_pair[0]],
            'model_2': row[model_pair[1]],
            'model1_name': model_pair[0],
            'model2_name': model_pair[1]
        }
        
        return battle_data
    
    def record_evaluation(self, preferred_models, input_text, output1, output2, model1_name, model2_name, is_causal):
        """Record user's model preference and update scores"""
        self.model_scores[model1_name]['total_comparisons'] += 1
        self.model_scores[model2_name]['total_comparisons'] += 1
        
        if preferred_models == "Both Good":
            self.model_scores[model1_name]['wins'] += 1
            self.model_scores[model2_name]['wins'] += 1
        elif preferred_models == "Model A":  # Maps to first model
            self.model_scores[model1_name]['wins'] += 1
        elif preferred_models == "Model B":  # Maps to second model
            self.model_scores[model2_name]['wins'] += 1
        # "Both Bad" case - no wins recorded
        
        evaluation = {
            'input_text': input_text,
            'output1': output1,
            'output2': output2,
            'model1_name': model1_name,
            'model2_name': model2_name,
            'preferred_models': preferred_models
        }
        if is_causal:
            self.evaluation_results_causal.append(evaluation)
        else:
            self.evaluation_results_masked.append(evaluation)
        
        # Calculate the total number of evaluations
        total_evaluations = len(self.evaluation_results_causal) + len(self.evaluation_results_masked)
        
        # Save results periodically
        if total_evaluations % self.saving_freq == 0:
            self.save_results()
            
        return self.get_model_scores_df(is_causal)
    
    def save_results(self):
        """Save the evaluation results to the Hub and to local CSV files."""
        # Combine the masked and causal leaderboards into a single results table
        results_df = pd.concat(
            [self.get_model_scores_df(is_causal=False), self.get_model_scores_df(is_causal=True)],
            ignore_index=True,
        )
        results_dataset = Dataset.from_pandas(results_df)
        results_dataset.push_to_hub('atlasia/Res-Moroccan-Darija-LLM-Battle-Al-Atlas', private=True, token=TOKEN)
        results_df.to_csv('human_eval_results.csv', index=False)
        
        # Also save the raw evaluation results
        masked_df = pd.DataFrame(self.evaluation_results_masked)
        causal_df = pd.DataFrame(self.evaluation_results_causal)
        
        if not masked_df.empty:
            masked_df.to_csv('masked_evaluations.csv', index=False)
        if not causal_df.empty:
            causal_df.to_csv('causal_evaluations.csv', index=False)
    
    def get_model_scores_df(self, is_causal):
        """Convert model scores to DataFrame"""
        scores_data = []
        for model, stats in self.model_scores.items():
            if is_causal:
                if model not in CAUSAL_LM_MODELS:
                    continue
            else:
                if model not in MASKED_LM_MODELS:
                    continue
            win_rate = (stats['wins'] / stats['total_comparisons'] * 100) if stats['total_comparisons'] > 0 else 0
            scores_data.append({
                'Model': model,
                'Wins': stats['wins'],
                'Total Comparisons': stats['total_comparisons'],
                'Win Rate (%)': round(win_rate, 2)
            })

        results_df = pd.DataFrame(scores_data)
        print("Generated DataFrame:\n", results_df)  # Debugging print
        
        # if 'Win Rate (%)' not in results_df.columns:
        #     raise ValueError("Win Rate (%) column is missing from DataFrame!")

        return results_df
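
# A minimal standalone usage sketch of LMBattleArena (not executed here; assumes the CSV
# has the column layout described above):
#
#   arena = LMBattleArena("human_eval_dataset.csv")
#   battle = arena.get_next_battle_pair(is_causal=False)
#   scores_df = arena.record_evaluation(
#       "Model A", battle['prompt'], battle['model_1'], battle['model_2'],
#       battle['model1_name'], battle['model2_name'], is_causal=False,
#   )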


def create_battle_arena(dataset_path, is_gif):
    """Build the Gradio Blocks UI for the masked- and causal-LM battle arenas."""
    arena = LMBattleArena(dataset_path)
    
    def battle_round(is_causal):
        """Fetch the next battle; values are returned in the same order as the Gradio
        outputs list: (prompt, output A, output B, model A name, model B name, leaderboard)."""
        battle_data = arena.get_next_battle_pair(is_causal)
        
        if battle_data is None:
            return "All model pairs have been evaluated for all examples!", "", "", "", "", gr.DataFrame(visible=False)
        
        return (
            battle_data['prompt'], 
            battle_data['model_1'], 
            battle_data['model_2'],
            battle_data['model1_name'], 
            battle_data['model2_name'],
            gr.DataFrame(visible=True)
        )
    
    def submit_preference(input_text, output_1, output_2, model1_name, model2_name, preferred_models, is_causal):
        scores_df = arena.record_evaluation(
            preferred_models, input_text, output_1, output_2, model1_name, model2_name, is_causal
        )
        next_battle = battle_round(is_causal)
        # Swap the DataFrame placeholder returned by battle_round for the updated leaderboard
        return (*next_battle[:-1], scores_df)

    with gr.Blocks(css="footer{display:none !important}") as demo:
        base_path = os.path.dirname(__file__)
        local_image_path = os.path.join(base_path, 'battle_leaderboard.gif')
        gr.HTML(create_html_media(local_image_path, is_gif=is_gif))
        
        with gr.Tabs():
            with gr.Tab("Masked LM Battle Arena"):
                gr.Markdown("# 🤖 Pretrained SmolLMs Battle Arena")
                
                # Use gr.State to store the boolean value without displaying it
                is_causal = gr.State(value=False)
                
                input_text = gr.Textbox(
                    label="Input prompt", 
                    interactive=False,
                )
                
                with gr.Row():
                    output_1 = gr.Textbox(
                        label="Model A", 
                        interactive=False
                    )
                    model1_name = gr.State()  # Hidden state for model1 name
                
                with gr.Row():
                    output_2 = gr.Textbox(
                        label="Model B", 
                        interactive=False
                    )
                    model2_name = gr.State()  # Hidden state for model2 name
                
                preferred_models = gr.Radio(
                    label="Which model is better?",
                    choices=["Model A", "Model B", "Both Good", "Both Bad"]
                )
                submit_btn = gr.Button("Vote", variant="primary")
                
                scores_table = gr.DataFrame(
                    headers=['Model', 'Wins', 'Total Comparisons', 'Win Rate (%)'],
                    label="🏆 Leaderboard"
                )
                
                submit_btn.click(
                    submit_preference,
                    inputs=[input_text, output_1, output_2, model1_name, model2_name, preferred_models, is_causal],
                    outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table]
                )
                
                demo.load(
                    battle_round, 
                    inputs=[is_causal],
                    outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table]
                )
                
            with gr.Tab("Causal LM Battle Arena"):
                gr.Markdown("# 🤖 Pretrained SmolLMs Battle Arena")
                
                # Use gr.State to store the boolean value without displaying it
                is_causal = gr.State(value=True)
                
                input_text = gr.Textbox(
                    label="Input prompt", 
                    interactive=False,
                )
                
                with gr.Row():
                    output_1 = gr.Textbox(
                        label="Model A", 
                        interactive=False
                    )
                    model1_name = gr.State()  # Hidden state for model1 name
                
                with gr.Row():
                    output_2 = gr.Textbox(
                        label="Model B", 
                        interactive=False
                    )
                    model2_name = gr.State()  # Hidden state for model2 name
                
                preferred_models = gr.Radio(
                    label="Which model is better?",
                    choices=["Model A", "Model B", "Both Good", "Both Bad"]
                )
                submit_btn = gr.Button("Vote", variant="primary")
                
                scores_table = gr.DataFrame(
                    headers=['Model', 'Wins', 'Total Comparisons', 'Win Rate (%)'],
                    label="🏆 Leaderboard"
                )
                
                submit_btn.click(
                    submit_preference,
                    inputs=[input_text, output_1, output_2, model1_name, model2_name, preferred_models, is_causal],
                    outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table]
                )
                
                demo.load(
                    battle_round, 
                    inputs=[is_causal],
                    outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table]
                )
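                # Both demo.load handlers (this tab and the masked tab above) run once when
                # the page loads, so each arena is seeded with its own initial battle pair.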
                        
    return demo

if __name__ == "__main__":
    
    # inference device (unused here: all model outputs are precomputed in the eval dataset)
    device = "cpu"
    dataset_path = 'human_eval_dataset.csv'
    is_gif = True
    
    # export the existing dataset that contains the LMs' precomputed outputs to a local CSV
    load_dataset("atlasia/LM-Moroccan-Darija-Bench", split='test', token=TOKEN).to_csv(dataset_path)  # atlasia/Moroccan-Darija-LLM-Battle-Al-Atlas

    demo = create_battle_arena(dataset_path, is_gif)
    demo.launch(debug=True)