import gradio as gr
import pandas as pd
import os
import zipfile
import base64

# Label shown above the citation textbox in the "📙 Citation" accordion.
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
# BibTeX entry for the leaderboard. Raw string so the "\url" macro is kept
# literally instead of being treated as an escape sequence.
CITATION_BUTTON_TEXT = r"""@misc{aienergyscore-leaderboard,
    author = {Sasha Luccioni and Boris Gamazaychikov and Emma Strubell and Sara Hooker and Yacine Jernite and Carole-Jean Wu and Margaret Mitchell},
    title = {AI Energy Score Leaderboard - February 2025},
    year = {2025},
    publisher = {Hugging Face},
    howpublished = "\url{https://huggingface.co/spaces/AIEnergyScore/Leaderboard}",
}"""

# List of tasks (CSV filenames)
# Each entry is a file under data/energy/. The order matters for the
# "All Tasks" view: rows are concatenated in this order and the first row
# seen for each model is the one kept.
tasks = [
    'asr.csv',
    'object_detection.csv',
    'text_classification.csv',
    'image_captioning.csv',
    'question_answering.csv',
    'text_generation.csv',
    'image_classification.csv',
    'sentence_similarity.csv',
    'image_generation.csv',
    'summarization.csv'
]

### HELPER FUNCTIONS ###

def format_stars(score):
    """Render an energy score as a row of green stars wrapped in an HTML span.

    Args:
        score: Anything ``int()`` accepts (int, numeric string, float).

    Returns:
        ``<span ...>★★★</span>`` with one star per integer point. Values that
        cannot be converted (None, non-numeric text, NaN, infinity) render as
        an empty span instead of raising.
    """
    try:
        star_count = int(score)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None/objects; ValueError: bad text or NaN;
        # OverflowError: float infinity. Anything else is a real bug and should surface.
        star_count = 0
    return f'<span style="color: #3fa45bff; font-size:1.5em;">{"★" * star_count}</span>'

def make_link(mname):
    """Return an HTML anchor pointing at the model's Hugging Face page.

    The link target is ``https://huggingface.co/<mname>``. The visible text is
    the second ``/``-separated segment of the name (the part after the
    organisation), or the whole name when it contains no ``/``.
    """
    segments = str(mname).split('/')
    label = mname if len(segments) < 2 else segments[1]
    return '<a href="https://huggingface.co/{}" target="_blank">{}</a>'.format(mname, label)

def extract_link_text(html_link):
    """Return the visible text between the first ``>`` and the last ``</a>``.

    If the string has no ``>``, or there is no non-empty text before the last
    ``</a>``, the input is returned unchanged.
    """
    open_end = html_link.find('>')
    close_start = html_link.rfind('</a>')
    if open_end < 0 or close_start <= open_end + 1:
        return html_link
    return html_link[open_end + 1:close_start]

def generate_html_table_from_df(df):
    """Render a leaderboard DataFrame as an HTML table inside a container div.

    Args:
        df: Frame with the columns ``Model`` (HTML anchor), ``Provider``,
            ``gpu_energy_numeric``, ``energy_score`` and ``Score`` (stars HTML).
            Rows are emitted in the frame's current order.

    Returns:
        ``<div class="table-container">`` wrapping a four-column table
        (Model, Provider, GPU Energy (Wh), Score). Each energy cell carries a
        bar whose width is that row's energy as a percentage of the largest
        energy in ``df``.
    """
    # Static width for the Model column, sized from the longest visible model name.
    if not df.empty:
        max_length = max(len(extract_link_text(link)) for link in df['Model'])
    else:
        max_length = 10
    static_width = max_length * 10 + 16

    max_energy = df['gpu_energy_numeric'].max() if not df.empty else 1
    # An all-zero or all-NaN energy column would otherwise produce "nan%"/"inf%"
    # bar widths; fall back to 1 so those bars render at 0%.
    if pd.isna(max_energy) or max_energy <= 0:
        max_energy = 1
    # Bar colour per energy score: every score 1-5 currently maps to black,
    # anything else falls back to gray.
    color_map = {"1": "black", "2": "black", "3": "black", "4": "black", "5": "black"}

    # Collect fragments and join once rather than growing a string with +=.
    parts = [
        '<table class="data-table" style="width:100%; border-collapse: collapse; font-family: Inter, sans-serif;">',
        '<thead><tr style="background-color: #f2f2f2;">',
        '<th style="text-align: left; padding: 8px;" title="Model name with link to Hugging Face">Model</th>',
        '<th style="text-align: left; padding: 8px;" title="AI Provider extracted from the model name">Provider</th>',
        '<th style="text-align: left; padding: 8px;" title="GPU energy consumed in Watt-hours for 1,000 queries">GPU Energy (Wh)</th>',
        '<th style="text-align: left; padding: 8px;" title="Energy efficiency score (stars)">Score</th>',
        '</tr></thead>',
        '<tbody>',
    ]
    for _, row in df.iterrows():
        energy_numeric = row['gpu_energy_numeric']
        energy_str = f"{energy_numeric:,.2f}"
        # A missing energy value gets an empty bar instead of "nan%".
        bar_width = (energy_numeric / max_energy) * 100 if pd.notna(energy_numeric) else 0.0
        bar_color = color_map.get(str(row['energy_score']), "gray")
        parts.append('<tr>')
        parts.append(f'<td style="padding: 8px; width: {static_width}px;">{row["Model"]}</td>')
        parts.append(f'<td style="padding: 8px;">{row["Provider"]}</td>')
        parts.append(
            f'<td style="padding: 8px;">{energy_str}<br>'
            f'<div style="background-color: {bar_color}; width: {bar_width:.1f}%; height: 10px;"></div></td>'
        )
        parts.append(f'<td style="padding: 8px;">{row["Score"]}</td>')
        parts.append('</tr>')
    parts.append('</tbody></table>')
    return f'<div class="table-container">{"".join(parts)}</div>'

def process_df(task, sort_order="Low to High", filter_fn=None):
    """Load one task CSV from ``data/energy`` and prepare it for display.

    Args:
        task: CSV filename under ``data/energy`` (e.g. ``'asr.csv'``).
        sort_order: ``"Low to High"`` sorts by ascending GPU energy; any other
            value sorts descending.
        filter_fn: Optional callable taking and returning a DataFrame. It runs
            after the numeric columns are computed and before the display
            columns are added.

    Returns:
        DataFrame with added columns ``gpu_energy_numeric``
        (``total_gpu_energy`` * 1000), ``Provider`` (text before the first
        ``/`` of ``model``), ``Model`` (HTML link) and ``Score`` (HTML stars),
        sorted by ``gpu_energy_numeric``.

    Raises:
        ValueError: if ``energy_score`` or ``total_gpu_energy`` is not numeric.
    """
    df = pd.read_csv(os.path.join("data", "energy", task))
    if df.columns[0].startswith("Unnamed:"):
        # Drop a leading unnamed column (presumably an index written by to_csv).
        # Copy so the column assignments below don't operate on a sliced view.
        df = df.iloc[:, 1:].copy()
    df['energy_score'] = df['energy_score'].astype(int)
    df['gpu_energy_numeric'] = pd.to_numeric(df['total_gpu_energy'], errors='raise') * 1000
    if filter_fn is not None:
        # Filters commonly return a boolean-mask slice; copy it so the new
        # columns below don't trigger SettingWithCopyWarning.
        df = filter_fn(df).copy()
    df['Provider'] = df['model'].apply(lambda x: str(x).split('/')[0])
    df['Model'] = df['model'].apply(make_link)
    df['Score'] = df['energy_score'].apply(format_stars)
    return df.sort_values(by='gpu_energy_numeric', ascending=sort_order == "Low to High")

def compute_efficiency_ratio(df):
    """Return max/min of ``df['gpu_energy_numeric']``.

    Returns 1 for an empty frame, or when the minimum is not strictly positive
    (zero, negative, or NaN), so callers never divide by zero.
    """
    if df.empty:
        return 1
    energies = df['gpu_energy_numeric']
    lowest, highest = energies.min(), energies.max()
    if lowest > 0:
        return highest / lowest
    return 1

def generate_info_callout(ratio, scope_text):
    """Build the right-aligned green "💡" box quoting the energy-use ratio.

    Args:
        ratio: Number rendered with thousands separators and one decimal.
        scope_text: Phrase completing "... energy use in <scope_text>."

    Returns:
        HTML string; the inner div carries the ``info-callout`` class.
    """
    box_style = (
        "display:inline-block; max-width:250px; font-size:0.8em; "
        "background-color:#e6ffe6; padding:8px; border-radius:5px;"
    )
    message = (
        "💡 There's a "
        f'<strong style="color: black !important;">{ratio:,.1f}x</strong>'
        f" difference between the highest and lowest energy use in {scope_text}."
    )
    return "".join([
        '<div style="text-align: right;">',
        f'<div class="info-callout" style="{box_style}">',
        message,
        "</div></div>",
    ])

def get_global_callout():
    """Build the header callout comparing energy use across every task CSV.

    Reads each file in ``tasks`` from ``data/energy`` and reports the max/min
    ratio of ``total_gpu_energy`` (scaled by 1000) over all rows combined.
    """
    energy_columns = []
    for task in tasks:
        df = pd.read_csv(os.path.join("data", "energy", task))
        energy_columns.append(pd.to_numeric(df['total_gpu_energy'], errors='raise') * 1000)
    # Concatenate once (not inside the loop) and keep only the column
    # compute_efficiency_ratio reads.
    if energy_columns:
        all_df = pd.DataFrame({'gpu_energy_numeric': pd.concat(energy_columns, ignore_index=True)})
    else:
        all_df = pd.DataFrame()
    ratio = compute_efficiency_ratio(all_df)
    return generate_info_callout(ratio, "this leaderboard")

### ZIP DOWNLOAD FUNCTIONS ###

def zip_csv_files():
    """Write every CSV in ``data/energy`` into ``data.zip`` and return its path.

    The archive is (re)created in the current working directory on each call.
    Entries are stored flat (no directory prefix) and added in sorted filename
    order, so the archive contents are reproducible.

    Returns:
        The archive filename, ``"data.zip"``.
    """
    data_dir = os.path.join("data", "energy")
    zip_filename = "data.zip"
    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
        # os.listdir returns entries in arbitrary order; sort for determinism.
        for filename in sorted(os.listdir(data_dir)):
            filepath = os.path.join(data_dir, filename)
            # Skip directories that merely have a ".csv"-suffixed name.
            if filename.endswith(".csv") and os.path.isfile(filepath):
                zipf.write(filepath, arcname=filename)
    return zip_filename

def get_zip_data_link():
    """Return a "Download Data" anchor embedding the CSV archive as a data URL.

    Calls ``zip_csv_files()`` to (re)build ``data.zip``, then base64-encodes
    its bytes into a ``data:application/zip`` href with a ``download``
    attribute, so the browser saves it as ``data.zip``.
    """
    archive_path = zip_csv_files()
    with open(archive_path, "rb") as handle:
        encoded = base64.b64encode(handle.read()).decode()
    link_style = "text-decoration: none; font-weight: bold; font-size: 1.1em; color: inherit; font-family: 'Inter', sans-serif;"
    return (
        f'<a class="header-link" href="data:application/zip;base64,{encoded}" '
        f'download="data.zip" style="{link_style}">Download Data</a>'
    )

### UPDATE FUNCTIONS (RETURNING CALLOUT AND TABLE HTML) ###

def update_text_generation(selected_display, sort_order):
    """Build callout and table HTML for the Text Generation tab.

    Args:
        selected_display: One of the model-class dropdown labels. Unknown
            labels fall back to class "A".
        sort_order: ``"Low to High"`` or ``"High to Low"`` (passed to process_df).

    Returns:
        ``(callout_html, table_html)``. The callout's scope is "this class".
    """
    display_to_class = {
        "A (Single Consumer GPU) <20B parameters": "A",
        "B (Single Cloud GPU) 20-66B parameters": "B",
        "C (Multiple Cloud GPUs) >66B parameters": "C"
    }
    wanted_class = display_to_class.get(selected_display, "A")

    def keep_selected_class(frame):
        # When the CSV has no 'class' column, rows are left unfiltered.
        if 'class' not in frame.columns:
            return frame
        return frame[frame['class'] == wanted_class]

    df = process_df('text_generation.csv', sort_order, keep_selected_class)
    callout = generate_info_callout(compute_efficiency_ratio(df), "this class")
    return callout, generate_html_table_from_df(df)

def _update_single_task(csv_name, sort_order):
    """Build ``(callout_html, table_html)`` for a tab backed by one task CSV.

    Shared by every single-task tab so they stay identical; the callout scope
    is always "this task".
    """
    df = process_df(csv_name, sort_order)
    ratio = compute_efficiency_ratio(df)
    callout = generate_info_callout(ratio, "this task")
    table_html = generate_html_table_from_df(df)
    return callout, table_html

def update_image_generation(sort_order):
    """Return callout and table HTML for the Image Generation tab."""
    return _update_single_task('image_generation.csv', sort_order)

def update_text_classification(sort_order):
    """Return callout and table HTML for the Text Classification tab."""
    return _update_single_task('text_classification.csv', sort_order)

def update_image_classification(sort_order):
    """Return callout and table HTML for the Image Classification tab."""
    return _update_single_task('image_classification.csv', sort_order)

def update_image_captioning(sort_order):
    """Return callout and table HTML for the Image Captioning tab."""
    return _update_single_task('image_captioning.csv', sort_order)

def update_summarization(sort_order):
    """Return callout and table HTML for the Summarization tab."""
    return _update_single_task('summarization.csv', sort_order)

def update_asr(sort_order):
    """Return callout and table HTML for the Automatic Speech Recognition tab."""
    return _update_single_task('asr.csv', sort_order)

def update_object_detection(sort_order):
    """Return callout and table HTML for the Object Detection tab."""
    return _update_single_task('object_detection.csv', sort_order)

def update_sentence_similarity(sort_order):
    """Return callout and table HTML for the Sentence Similarity tab."""
    return _update_single_task('sentence_similarity.csv', sort_order)

def update_extractive_qa(sort_order):
    """Return callout and table HTML for the Extractive QA tab."""
    return _update_single_task('question_answering.csv', sort_order)

def update_all_tasks(sort_order):
    """Build callout and table HTML for the "All Tasks" tab.

    Loads every CSV in ``tasks`` (in list order), concatenates them, keeps the
    first row seen for each ``model`` value, and sorts by GPU energy.

    Args:
        sort_order: ``"Low to High"`` for ascending energy; anything else sorts
            descending.

    Returns:
        ``(callout_html, table_html)``. The callout's scope is "this leaderboard".
    """
    frames = []
    for task in tasks:
        df = pd.read_csv(os.path.join("data", "energy", task))
        if df.columns[0].startswith("Unnamed:"):
            # Drop a leading unnamed column (presumably an index written by
            # to_csv). Copy so the assignments below don't hit a sliced view.
            df = df.iloc[:, 1:].copy()
        df['energy_score'] = df['energy_score'].astype(int)
        df['gpu_energy_numeric'] = pd.to_numeric(df['total_gpu_energy'], errors='raise') * 1000
        frames.append(df)

    if frames:
        # Concatenate once instead of re-concatenating inside the loop (quadratic).
        all_df = pd.concat(frames, ignore_index=True)
        all_df['Provider'] = all_df['model'].apply(lambda x: str(x).split('/')[0])
        all_df['Model'] = all_df['model'].apply(make_link)
        all_df['Score'] = all_df['energy_score'].apply(format_stars)
        # Keeps the first occurrence, i.e. the row from the earliest task in `tasks`.
        all_df = all_df.drop_duplicates(subset=['model'])
        all_df = all_df.sort_values(by='gpu_energy_numeric', ascending=sort_order == "Low to High")
    else:
        # No task files configured: render an empty table rather than crashing.
        all_df = pd.DataFrame()

    ratio = compute_efficiency_ratio(all_df)
    callout = generate_info_callout(ratio, "this leaderboard")
    table_html = generate_html_table_from_df(all_df)
    return callout, table_html

### GLOBAL HEADER (Logo & Global Callout) ###

# Use a <picture> element so that dark mode uses logodark.png.
# Static header HTML: the logo is centred and the global callout is absolutely
# positioned at the right edge. get_global_callout() runs once, when this
# module-level f-string is evaluated, so the ratio reflects the CSVs present
# at startup.
global_header_html = f"""
<div style="position: relative; width: 100%; text-align: center; margin-bottom: 20px;">
  <picture style="display:inline-block;">
    <source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/spaces/AIEnergyScore/Leaderboard/resolve/main/logodark.png">
    <img src="https://huggingface.co/spaces/AIEnergyScore/Leaderboard/resolve/main/logo.png" 
         alt="Logo" 
         style="width:300px; max-width:300px; height:auto; display:inline-block;">
  </picture>
  <div style="position: absolute; top: 50%; right: 20px; transform: translateY(-50%);">
      {get_global_callout()}
  </div>
</div>
"""

### CUSTOM CSS for Dark Mode and Mobile Responsiveness ###
# Stylesheet passed to gr.Blocks(css=...).
# - .data-table / .table-container style the tables built by
#   generate_html_table_from_df (fixed layout, ellipsised cells).
# - .header-link styles the top navigation links, including Download Data.
# - Under prefers-color-scheme: dark, header-row text is forced black because
#   the header <tr> is rendered with an inline light-gray background (#f2f2f2).
#   The callout text is also kept black on its light-green box.
# - Below 600px width, the .info-callout boxes are hidden.
custom_css = """
/* Table and layout */
.data-table {
    table-layout: fixed;
    width: 100%;
}
.data-table th, .data-table td {
    max-width: 150px;
    white-space: nowrap;
    overflow: hidden;
    text-overflow: ellipsis;
}
.table-container {
    width: 100%;
    margin-left: auto;
    margin-right: auto;
}

/* Force header links to be black in light mode */
.header-link {
    color: black !important;
}

/* Dark mode styles */
@media (prefers-color-scheme: dark) {
    body {
        background-color: #121212;
        color: #e0e0e0;
    }
    /* Make header links white */
    .header-link {
         color: white !important;
    }
    .data-table thead {
        background-color: #333;
    }
    /* Make table header text black in dark mode */
    .data-table th {
        color: black !important;
    }
    .data-table td {
        color: #e0e0e0;
    }
    /* Make callout text black */
    .info-callout {
        color: black !important;
    }
    /* Non-header links in dark mode */
    a:not(.header-link) {
        color: #3fa45bff !important;
    }
}

/* Mobile styles: hide callout boxes on small screens */
@media (max-width: 600px) {
    .info-callout {
        display: none !important;
    }
}
"""

### GRADIO INTERFACE ###
# Single Blocks app; custom_css supplies the table, dark-mode and mobile rules.
demo = gr.Blocks(css=custom_css)

with demo:
    # --- Header Links ---
    # Navigation row of links. get_zip_data_link() is evaluated here, once at
    # build time, and embeds the CSV archive as a base64 data URL.
    gr.HTML(f"""
    <div style="display: flex; justify-content: space-evenly; align-items: center; margin-bottom: 20px;">
        <a class="header-link" href="https://huggingface.co/spaces/AIEnergyScore/submission_portal" style="text-decoration: none; font-weight: bold; font-size: 1.1em;">Submission Portal</a>
        <a class="header-link" href="https://huggingface.co/spaces/AIEnergyScore/Label" style="text-decoration: none; font-weight: bold; font-size: 1.1em;">Label Generator</a>
        <a class="header-link" href="https://huggingface.github.io/AIEnergyScore/#faq" style="text-decoration: none; font-weight: bold; font-size: 1.1em;">FAQ</a>
        <a class="header-link" href="https://huggingface.github.io/AIEnergyScore/#documentation" style="text-decoration: none; font-weight: bold; font-size: 1.1em;">Documentation</a>
        {get_zip_data_link()}
        <a class="header-link" href="https://huggingface.co/spaces/AIEnergyScore/README/discussions" style="text-decoration: none; font-weight: bold; font-size: 1.1em;">Community</a>
    </div>
    """)
    
    # --- Global Header: Centered Logo with Global Callout at Right Edge ---
    gr.HTML(global_header_html)
    
    # --- Tabs for the different tasks ---
    # Every tab follows the same pattern:
    #   1. A Row holds the controls and an HTML callout.
    #   2. An HTML table sits below the Row.
    #   3. The initial HTML is computed eagerly with the default "Low to High" sort.
    #   4. .change handlers recompute (callout, table) whenever an input changes.
    with gr.Tabs():
        # --- Text Generation Tab ---
        # Two inputs (model class and sort order). Changing either one re-runs
        # update_text_generation with both current values.
        with gr.TabItem("Text Generation 💬"):
            with gr.Row():
                with gr.Column(scale=4):
                    model_class_options = [
                        "A (Single Consumer GPU) <20B parameters",
                        "B (Single Cloud GPU) 20-66B parameters",
                        "C (Multiple Cloud GPUs) >66B parameters"
                    ]
                    model_class_dropdown = gr.Dropdown(choices=model_class_options, label="Select Model Class", value=model_class_options[0])
                with gr.Column(scale=4):
                    sort_dropdown_tg = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    tg_callout = gr.HTML()
            tg_table = gr.HTML()
            # Initial render, assigned to the components before the events are wired.
            init_callout, init_table = update_text_generation(model_class_options[0], "Low to High")
            tg_callout.value = init_callout
            tg_table.value = init_table
            model_class_dropdown.change(fn=update_text_generation, inputs=[model_class_dropdown, sort_dropdown_tg], outputs=[tg_callout, tg_table])
            sort_dropdown_tg.change(fn=update_text_generation, inputs=[model_class_dropdown, sort_dropdown_tg], outputs=[tg_callout, tg_table])
        
        # --- Image Generation Tab ---
        with gr.TabItem("Image Generation 📷"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_img = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    img_callout = gr.HTML()
            img_table = gr.HTML()
            init_callout, init_table = update_image_generation("Low to High")
            img_callout.value = init_callout
            img_table.value = init_table
            sort_dropdown_img.change(fn=update_image_generation, inputs=sort_dropdown_img, outputs=[img_callout, img_table])
        
        # --- Text Classification Tab ---
        with gr.TabItem("Text Classification 🎭"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_tc = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    tc_callout = gr.HTML()
            tc_table = gr.HTML()
            init_callout, init_table = update_text_classification("Low to High")
            tc_callout.value = init_callout
            tc_table.value = init_table
            sort_dropdown_tc.change(fn=update_text_classification, inputs=sort_dropdown_tc, outputs=[tc_callout, tc_table])
        
        # --- Image Classification Tab ---
        with gr.TabItem("Image Classification 🖼️"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_ic = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    ic_callout = gr.HTML()
            ic_table = gr.HTML()
            init_callout, init_table = update_image_classification("Low to High")
            ic_callout.value = init_callout
            ic_table.value = init_table
            sort_dropdown_ic.change(fn=update_image_classification, inputs=sort_dropdown_ic, outputs=[ic_callout, ic_table])
        
        # --- Image Captioning Tab ---
        with gr.TabItem("Image Captioning 📝"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_icap = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    icap_callout = gr.HTML()
            icap_table = gr.HTML()
            init_callout, init_table = update_image_captioning("Low to High")
            icap_callout.value = init_callout
            icap_table.value = init_table
            sort_dropdown_icap.change(fn=update_image_captioning, inputs=sort_dropdown_icap, outputs=[icap_callout, icap_table])
        
        # --- Summarization Tab ---
        with gr.TabItem("Summarization 📃"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_sum = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    sum_callout = gr.HTML()
            sum_table = gr.HTML()
            init_callout, init_table = update_summarization("Low to High")
            sum_callout.value = init_callout
            sum_table.value = init_table
            sort_dropdown_sum.change(fn=update_summarization, inputs=sort_dropdown_sum, outputs=[sum_callout, sum_table])
        
        # --- Automatic Speech Recognition Tab ---
        with gr.TabItem("Automatic Speech Recognition 💬"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_asr = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    asr_callout = gr.HTML()
            asr_table = gr.HTML()
            init_callout, init_table = update_asr("Low to High")
            asr_callout.value = init_callout
            asr_table.value = init_table
            sort_dropdown_asr.change(fn=update_asr, inputs=sort_dropdown_asr, outputs=[asr_callout, asr_table])
        
        # --- Object Detection Tab ---
        with gr.TabItem("Object Detection 🚘"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_od = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    od_callout = gr.HTML()
            od_table = gr.HTML()
            init_callout, init_table = update_object_detection("Low to High")
            od_callout.value = init_callout
            od_table.value = init_table
            sort_dropdown_od.change(fn=update_object_detection, inputs=sort_dropdown_od, outputs=[od_callout, od_table])
        
        # --- Sentence Similarity Tab ---
        with gr.TabItem("Sentence Similarity 📚"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_ss = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    ss_callout = gr.HTML()
            ss_table = gr.HTML()
            init_callout, init_table = update_sentence_similarity("Low to High")
            ss_callout.value = init_callout
            ss_table.value = init_table
            sort_dropdown_ss.change(fn=update_sentence_similarity, inputs=sort_dropdown_ss, outputs=[ss_callout, ss_table])
        
        # --- Extractive QA Tab ---
        # Backed by question_answering.csv (see update_extractive_qa).
        with gr.TabItem("Extractive QA ❔"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_qa = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    qa_callout = gr.HTML()
            qa_table = gr.HTML()
            init_callout, init_table = update_extractive_qa("Low to High")
            qa_callout.value = init_callout
            qa_table.value = init_table
            sort_dropdown_qa.change(fn=update_extractive_qa, inputs=sort_dropdown_qa, outputs=[qa_callout, qa_table])
        
        # --- All Tasks Tab ---
        # Combines every task CSV, keeping one row per model (see update_all_tasks).
        with gr.TabItem("All Tasks 💡"):
            with gr.Row():
                with gr.Column(scale=8):
                    sort_dropdown_all = gr.Dropdown(choices=["Low to High", "High to Low"], label="Sort", value="Low to High")
                with gr.Column(scale=4):
                    all_callout = gr.HTML()
            all_table = gr.HTML()
            init_callout, init_table = update_all_tasks("Low to High")
            all_callout.value = init_callout
            all_table.value = init_table
            sort_dropdown_all.change(fn=update_all_tasks, inputs=sort_dropdown_all, outputs=[all_callout, all_table])
    
    # Collapsed accordion holding the BibTeX citation in a copyable textbox.
    with gr.Accordion("📙 Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            lines=10,
            show_copy_button=True,
        )
    gr.Markdown("Last updated: February 2025")

# Start the Gradio server. This runs on import as well as on direct execution,
# since it is not behind an `if __name__ == "__main__":` guard.
demo.launch()