import ast
import os
import random
import sys
from typing import List

import gradio as gr
import llm_blender
from datasets import load_dataset
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks

import descriptions


# Slider bounds/defaults. Their only uses are in the LLM-Blender tab, which is
# currently commented out in the UI section of this file.
MAX_BASE_LLM_NUM = 20
MIN_BASE_LLM_NUM = 3
SOURCE_MAX_LENGTH = 256
DEFAULT_SOURCE_MAX_LENGTH = 128
CANDIDATE_MAX_LENGTH = 256
DEFAULT_CANDIDATE_MAX_LENGTH = 128
FUSER_MAX_NEW_TOKENS = 512
DEFAULT_FUSER_MAX_NEW_TOKENS = 256


# MIX-INSTRUCT
# Number of shuffled validation examples pulled from the streaming dataset.
NUM_MIX_INSTRUCT_EXAMPLES = 100

EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42)
# [instruction, input] rows for gr.Examples.
MIX_INSTRUCT_EXAMPLES = []
# instruction + input (plain concatenation) -> that example's candidate list.
CANDIDATE_MAP = {}
# `take` already bounds the stream, so no manual counter/break is needed.
for example in SHUFFLED_EXAMPLES_DATASET.take(NUM_MIX_INSTRUCT_EXAMPLES):
    MIX_INSTRUCT_EXAMPLES.append([
        example['instruction'],
        example['input'],
    ])
    CANDIDATE_MAP[example['instruction'] + example['input']] = example['candidates']

# HHH ALIGNMENT
# Pairwise examples from HuggingFaceH4/hhh_alignment (the 'test' split of each
# subset below). Each row is [subset, input, response 1, response 2, label].
# The two choices are displayed in a random order (one draw from the globally
# seeded RNG per example); the label points at the displayed position of
# choices[0] when labels[0] == 1, otherwise at the position of choices[1].
HHH_EXAMPLES = []
subsets = ['harmless', 'helpful', 'honest', 'other']
random.seed(42)
for subset in subsets:
    dataset = load_dataset("HuggingFaceH4/hhh_alignment", subset)
    for example in dataset['test']:
        keep_order = random.random() < 0.5
        choices = example['targets']['choices']
        first_choice_is_better = example['targets']['labels'][0] == 1
        if keep_order:
            shown_first, shown_second = choices[0], choices[1]
        else:
            shown_first, shown_second = choices[1], choices[0]
        # The first displayed response wins exactly when the "better" choice
        # ended up in the first slot.
        first_shown_wins = first_choice_is_better == keep_order
        HHH_EXAMPLES.append([
            subset,
            example['input'],
            shown_first,
            shown_second,
            "Response 1" if first_shown_wins else "Response 2",
        ])
def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
    """Map one HHH_EXAMPLES row onto the comparison widgets.

    gr.Examples passes every column of the row, so ``subset`` and
    ``dummy_text`` (the better-response label) are accepted but ignored.

    Returns:
        tuple: ``(instruction, response1, response2)``.
    """
    widget_values = (instruction, response1, response2)
    return widget_values

# MT_BENCH_HUMAN_JUDGMENTS
# First-turn human pairwise judgments from lmsys/mt_bench_human_judgments.
# Each row: [model_a, model_b, str(conversation_a), str(conversation_b), winner label].
# Conversations are stored via str() because the example columns feed hidden
# textboxes; get_mt_bench_human_judge_examples parses them back.
MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
dataset = load_dataset("lmsys/mt_bench_human_judgments")
for example in dataset['human']:
    if example['turn'] != 1:
        continue
    # Anything other than a clear model_a / model_b win (e.g. a tie) must not
    # be reported as a "Model B" win.
    if example['winner'] == 'model_a':
        winner_label = "Model A"
    elif example['winner'] == 'model_b':
        winner_label = "Model B"
    else:
        winner_label = "Tie"
    MT_BENCH_HUMAN_JUDGE_EXAMPLES.append([
        example['model_a'],
        example['model_b'],
        str(example['conversation_a']),
        str(example['conversation_b']),
        winner_label,
    ])
def _conversation_to_chat_pairs(conversation_repr):
    """Parse a ``str()``-serialized conversation into (user, assistant) pairs.

    Args:
        conversation_repr: ``str()`` of a list of ``{"role": ..., "content": ...}``
            dicts whose roles alternate user, assistant, user, assistant, ...

    Returns:
        list[tuple]: one ``(user_content, assistant_content)`` pair per turn.

    Raises:
        ValueError: if the text is not a Python literal, holds an odd number of
            messages, or a turn is not ordered user -> assistant.
    """
    # literal_eval only accepts literals, so unlike eval() it cannot execute
    # code that arrives through the text-box-backed argument.
    messages = ast.literal_eval(conversation_repr)
    if len(messages) % 2 != 0:
        raise ValueError(f"Expected an even number of messages, got {len(messages)}")
    pairs = []
    for user_msg, assistant_msg in zip(messages[::2], messages[1::2]):
        # Validate before use (the previous assert ran after indexing and is
        # stripped under `python -O`).
        if user_msg['role'] != 'user' or assistant_msg['role'] != 'assistant':
            raise ValueError(
                f"Expected a user/assistant turn, got "
                f"{user_msg['role']!r}/{assistant_msg['role']!r}"
            )
        pairs.append((user_msg['content'], assistant_msg['content']))
    return pairs


def get_mt_bench_human_judge_examples(model_a, model_b, conversation_a, conversation_b, dummy_text):
    """Turn one MT_BENCH_HUMAN_JUDGE_EXAMPLES row into two chat histories.

    ``model_a``, ``model_b`` and ``dummy_text`` (the winner label) are accepted
    because gr.Examples passes every column of the row, but are not used.

    Args:
        conversation_a: ``str()`` of conversation A's message list.
        conversation_b: ``str()`` of conversation B's message list.

    Returns:
        tuple: ``(chat_history_a, chat_history_b)``, each a list of
        ``(user, assistant)`` pairs.

    Raises:
        ValueError: if either conversation is malformed.
    """
    chat_history_a = _conversation_to_chat_pairs(conversation_a)
    chat_history_b = _conversation_to_chat_pairs(conversation_b)
    return chat_history_a, chat_history_b


# Shared LLM-Blender instance used by the comparison/ranking/fusing handlers in
# this file. Created and loaded at import time; loadranker presumably fetches
# the "llm-blender/PairRM" checkpoint — confirm against llm_blender docs.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")
# Fuser loading is disabled. llms_fuse (wired only in the commented-out
# LLM-Blender tab) presumably requires it — re-enable together with that tab.
# blender.loadfuser("llm-blender/gen_fuser_3b")

def update_base_llms_num(k, llm_outputs):
    """Resize the base-LLM selector to ``k`` entries.

    Args:
        k: Desired number of base LLMs (coerced with ``int``).
        llm_outputs: Mapping of base-LLM name -> saved output text.

    Returns:
        list: ``[dropdown, outputs]`` — a Dropdown listing ``LLM-1..LLM-k``
        (``LLM-1`` selected when ``k >= 1``) and the saved outputs trimmed to
        those names, with ``""`` for names that have nothing saved.
    """
    count = int(k)
    llm_names = [f"LLM-{idx}" for idx in range(1, count + 1)]
    selector = gr.Dropdown(choices=llm_names,
        value="LLM-1" if count >= 1 else "", visible=True)
    trimmed_outputs = {name: llm_outputs.get(name, "") for name in llm_names}
    return [selector, trimmed_outputs]
    

def display_llm_output(llm_outputs, selected_base_llm_name):
    """Build the editor textbox for the selected base LLM.

    Args:
        llm_outputs: Mapping of base-LLM name -> saved output text.
        selected_base_llm_name: Name chosen in the dropdown (must be a str;
            it is concatenated into the label).

    Returns:
        gr.Textbox: pre-filled with the saved output (or ``""``).
    """
    saved_text = llm_outputs.get(selected_base_llm_name, "")
    editor_label = selected_base_llm_name + " (Click Save to save current content)"
    editor_hint = f"Enter {selected_base_llm_name} output here"
    return gr.Textbox(value=saved_text, label=editor_label,
        placeholder=editor_hint, show_label=True)

def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
    """Store the edited output of one base LLM in the saved-outputs mapping.

    Args:
        selected_base_llm_name: Key such as ``"LLM-1"``.
        selected_base_llm_output: Text currently in the editor textbox.
        llm_outputs: Mapping of base-LLM name -> saved output; ``None`` is
            treated as an empty mapping.

    Returns:
        dict: the mapping with ``selected_base_llm_name`` set to the new text
        (a passed-in dict is updated in place and returned).
    """
    if llm_outputs is None:
        llm_outputs = {}
    # Item assignment — calling the dict (``llm_outputs({...})``) raises
    # TypeError: 'dict' object is not callable.
    llm_outputs[selected_base_llm_name] = selected_base_llm_output
    return llm_outputs

def get_preprocess_examples(inst, input):
    """Expand a MIX-INSTRUCT example into the values its widgets need.

    Args:
        inst: Example instruction.
        input: Example input context.

    Returns:
        tuple: ``(inst, input, number_of_candidates, inst + input)``. The last
        value equals the CANDIDATE_MAP key; it is presumably routed to a hidden
        textbox so its change event fires per example (see the commented-out
        examples wiring).

    Raises:
        KeyError: if ``inst + input`` is not a CANDIDATE_MAP key.
    """
    example_key = inst + input
    num_candidates = len(CANDIDATE_MAP[example_key])
    return inst, input, num_candidates, example_key

def update_base_llm_dropdown_along_examples(inst, input):
    """Load an example's candidate outputs and reset the blender outputs.

    Args:
        inst: Example instruction.
        input: Example input context.

    Returns:
        tuple: ``(outputs, "", "", dropdown)`` where ``outputs`` maps
        ``LLM-i`` -> the i-th candidate's ``'text'``, the two empty strings
        clear the rank/fuse output boxes, and ``dropdown`` lists the names.
    """
    example_candidates = CANDIDATE_MAP[inst + input]
    llm_names = [f"LLM-{idx}" for idx in range(1, len(example_candidates) + 1)]
    ex_llm_outputs = dict(zip(llm_names, (cand['text'] for cand in example_candidates)))
    selector = gr.Dropdown(choices=llm_names, value="LLM-1" if llm_names else "", visible=True)
    return ex_llm_outputs, "", "", selector
    
def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
    """Validate ranker inputs and snapshot them as the blender state.

    Args:
        inst: Instruction text.
        input: Input-context text; at least one of ``inst``/``input`` must be
            non-empty.
        llm_outputs: Mapping of base-LLM name -> saved output text.
        blender_config: Unused here (accepted to match the event's inputs list).

    Returns:
        dict: ``{"inst", "input", "candidates"}``, with ``candidates`` listing
        the saved outputs in the mapping's iteration order.

    Raises:
        gr.Error: if both ``inst`` and ``input`` are empty, nothing is saved,
            or any saved output is empty.
    """
    if not inst and not input:
        raise gr.Error("Please enter instruction or input context")

    # all([]) is True, so without this an empty candidate list would pass.
    if not llm_outputs:
        raise gr.Error("Please enter base LLM outputs")

    empty_llm_names = [llm_name for llm_name, llm_output in llm_outputs.items() if not llm_output]
    if empty_llm_names:
        # Format the message before building the exception; `.format` on the
        # gr.Error instance raised AttributeError instead of this error.
        raise gr.Error("Please enter base LLM outputs for LLMs: {}".format(empty_llm_names))
    return {
        "inst": inst,
        "input": input,
        "candidates": list(llm_outputs.values()),
    }

def check_fuser_inputs(blender_state, blender_config, ranks):
    """Ensure ranking has run and a source text exists before fusing.

    Args:
        blender_state: Snapshot dict produced by check_save_ranker_inputs.
        blender_config: Unused here.
        ranks: Ranks from the last ranking run (must be non-empty).

    Raises:
        gr.Error: if ranking has not run or both instruction and input are empty.
    """
    ranked = "candidates" in blender_state and len(ranks) != 0
    if not ranked:
        raise gr.Error("Please rank LLM outputs first")
    has_source = blender_state.get("inst", None) or blender_state.get("input", None)
    if not has_source:
        raise gr.Error("Please enter instruction or input context")
    return None

def llms_rank(inst, input, llm_outputs, blender_config):
    """Rank the saved base-LLM outputs with the PairRM ranker.

    Args:
        inst: Instruction text.
        input: Input-context text.
        llm_outputs: Mapping of base-LLM name -> output text; values are ranked
            in iteration order.
        blender_config: Must contain ``source_max_length`` and
            ``candidate_max_length`` (a missing key raises KeyError).

    Returns:
        list: ``[ranks, summary]`` — the rank per candidate and a display
        string like ``"LLM-1: 2,  LLM-2: 1"``.
    """
    candidate_texts = list(llm_outputs.values())
    # NOTE(review): these values are read but NOT forwarded to blender.rank,
    # so the length sliders currently have no effect. Confirm the keyword
    # names blender.rank accepts before wiring them through.
    rank_params = dict(
        source_max_length=blender_config['source_max_length'],
        candidate_max_length=blender_config['candidate_max_length'],
    )
    batch_ranks = blender.rank(instructions=[inst], inputs=[input], candidates=[candidate_texts])
    ranks = batch_ranks[0]
    summary = ",  ".join(f"LLM-{pos}: {rank}" for pos, rank in enumerate(ranks, start=1))
    return [ranks, summary]


def llms_fuse(blender_state, blender_config, ranks):
    """Fuse the top-k ranked candidates into one output.

    Args:
        blender_state: Dict with ``inst``, ``input`` and ``candidates``.
        blender_config: Generation settings; must contain ``top_k_for_fuser``
            and ``source_max_length`` (both removed before the remaining keys
            are passed to ``blender.fuse``).
        ranks: Ranks of ``candidates`` from the ranker.

    Returns:
        list: the fused text twice (for the state and the display textbox).
    """
    inst, input, candidates = (blender_state[key] for key in ("inst", "input", "candidates"))
    fuse_params = dict(blender_config)
    top_k_for_fuser = fuse_params.pop("top_k_for_fuser")
    del fuse_params["source_max_length"]
    # NOTE(review): candidate_max_length is still forwarded to blender.fuse — confirm it is accepted there.
    fuse_params['no_repeat_ngram_size'] = 3
    selected = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
    fused_text = blender.fuse(
        instructions=[inst], inputs=[input], candidates=[selected],
        **fuse_params, batch_size=1,
    )[0]
    return [fused_text, fused_text]

def display_fuser_output(fuser_output):
    """Identity pass-through: return ``fuser_output`` unchanged."""
    shown = fuser_output
    return shown

        
# Gradio UI. Components render in the order they are created inside these
# layout contexts, so the statement order below is the page layout.
with gr.Blocks(theme='ParityError/Anime') as demo:



    # Only the "PairRM" tab is active; the "LLM-Blender" tab further down is
    # commented out.
    with gr.Tab("PairRM"):
        # PairRM interface
        # Header row: overview text next to the PairRanker diagram.
        with gr.Row():
            gr.Markdown(descriptions.PairRM_OVERALL_DESC)
            gr.Image("https://yuchenlin.xyz/LLM-Blender/pairranker.png")

        # Sub-tab 1: compare two free-text responses to one instruction.
        with gr.Tab("Compare two responses"):
            instruction = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
            # Side-by-side response editors.
            with gr.Row():
                response1 = gr.Textbox(lines=4, label="Response 1", placeholder="Enter response 1 here", show_label=True)
                response2 = gr.Textbox(lines=4, label="Response 2", placeholder="Enter response 2 here", show_label=True)
            with gr.Row():
                compare_button = gr.Button('Compare', variant='primary')
                clear_button = gr.Button('Clear', variant='primary')
            # Verdict and confidence text, filled by compare_fn via compare_button.
            with gr.Row():
                compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
                compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)

            def compare_fn(inst, response1, response2):
                """Compare two responses to the same instruction with PairRM.

                Args:
                    inst: Instruction shared by both responses.
                    response1: First candidate response.
                    response2: Second candidate response.

                Returns:
                    ``[result, prob]``: a human-readable verdict and a confidence
                    string built from the magnitude of the PairRM logit.

                Raises:
                    gr.Error: if the instruction or either response is empty.
                """
                if not inst:
                    raise gr.Error("Please enter instruction")
                if not response1 or not response2:
                    raise gr.Error("Please enter response 1 and response 2")
                comparison_results = blender.compare([inst], [response1], [response2], return_logits=True)
                logit = comparison_results[0]
                # The sign of the logit picks the winner; its magnitude is shown as confidence.
                if logit > 0:
                    result = "Response 1 is better than Response 2"
                    prob = f"Confidence: {round(logit, 2)}"
                elif logit < 0:
                    result = "Response 2 is better than Response 1"
                    prob = f"Confidence: {round(abs(logit), 2)}"
                else:
                    result = "Response 1 and Response 2 are equally good"
                    prob = "No confidence for tie"

                return [result, prob]
            # Compare runs compare_fn on the three inputs and fills both result boxes.
            compare_button.click(
                fn=compare_fn,
                inputs=[instruction, response1, response2],
                outputs=[compare_result, compare_result_prob],
            )
            # Clear resets all five textboxes to empty strings.
            clear_button.click(
                fn=lambda: ["", "", "", "", ""],
                inputs=[],
                outputs=[instruction, response1, response2, compare_result, compare_result_prob],
            )

            # Hidden (visible=False) textboxes that receive the "subset" and
            # "better response" columns of each HHH_EXAMPLES row, so every
            # column of a row has an input component.
            hhh_dummy_textbox1 = gr.Textbox(lines=1, label="subset", placeholder="", show_label=False, visible=False)
            hhh_dummy_textbox2 = gr.Textbox(lines=1, label="Better Response", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [HuggingFaceH4/hhh_alignment](https://huggingface.co/datasets/HuggingFaceH4/hhh_alignment)")
            # inputs order must match the HHH_EXAMPLES row layout
            # [subset, instruction, response 1, response 2, label]. With
            # cache_examples=True, Gradio precomputes get_hhh_examples' outputs
            # for the examples (see gr.Examples docs).
            gr.Examples(
                HHH_EXAMPLES,
                fn=get_hhh_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[hhh_dummy_textbox1, instruction, response1, response2, hhh_dummy_textbox2],
                outputs=[instruction, response1, response2],
            )


        # Sub-tab 2: compare the assistant side of two multi-turn conversations.
        with gr.Tab("Compare assistant's response in two multi-turn conversations"):

            # NOTE(review): the note text below is ungrammatical; suggested wording:
            # "NOTE: the two conversations are compared assuming the user query in
            # each turn is identical in both conversations."
            gr.Markdown("NOTE: Comparison of two conversations is based on that the user query in each turn is the same of two conversations.")
            def append_message(message, chat_history):
                """Add one typed message to a chat history, alternating roles.

                A message fills the pending bot reply of the last turn if that turn
                still shows the placeholder; otherwise it starts a new user turn
                with the placeholder as its reply. Empty messages are ignored.
                The history list is modified in place.

                Returns:
                    ``("", chat_history)`` — the empty string clears the textbox.
                """
                awaiting_reply = "(Please enter your bot response)"
                if message:
                    if len(chat_history) > 0 and chat_history[-1][1] == awaiting_reply:
                        user_text = chat_history[-1][0]
                        chat_history[-1] = (user_text, message)
                    else:
                        chat_history.append((message, awaiting_reply))
                return "", chat_history
            # Two side-by-side chat editors. Submitting a message passes
            # (text, history) to append_message, which alternates user and bot
            # turns and returns "" to clear the textbox.
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Conversation A")
                    chatbot1 = gr.Chatbot()
                    msg1 = gr.Textbox(lines=1, label="Enter Chat history for Conversation A", placeholder="Enter your message here", show_label=True)
                    clear1 = gr.ClearButton([msg1, chatbot1])
                    msg1.submit(append_message, [msg1, chatbot1], [msg1, chatbot1])
                with gr.Column():
                    gr.Markdown("### Conversation B")
                    chatbot2 = gr.Chatbot()
                    msg2 = gr.Textbox(lines=1, label="Enter Chat history for Conversation B", placeholder="Enter your message here", show_label=True)
                    clear2 = gr.ClearButton([msg2, chatbot2])
                    msg2.submit(append_message, [msg2, chatbot2], [msg2, chatbot2])
            # Rebinding compare_button / compare_result / compare_result_prob is
            # intentional: the first sub-tab's handlers were already attached to
            # the earlier component objects.
            with gr.Row():
                compare_button = gr.Button('Compare', variant='primary')
            with gr.Row():
                compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
                compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)

            def compare_conv_fn(chat_history1, chat_history2):
                """Compare the assistant turns of two conversations with PairRM.

                Args:
                    chat_history1: Conversation A as a sequence of
                        ``(user message, assistant message)`` pairs.
                    chat_history2: Conversation B, same format.

                Returns:
                    ``[result, prob]``: a human-readable verdict and a confidence
                    string built from the magnitude of the PairRM logit.

                Raises:
                    gr.Error: if either history is empty/None or its last turn still
                        holds the "(Please enter your bot response)" placeholder.
                """
                pending_reply = "(Please enter your bot response)"
                if not chat_history1 or not chat_history2:
                    raise gr.Error("Please enter chat history for both conversations")
                # A gr.Error instead of `assert`: asserts are stripped under
                # `python -O` and are not shown to the user like gr.Error messages.
                if chat_history1[-1][1] == pending_reply or chat_history2[-1][1] == pending_reply:
                    raise gr.Error("Please complete chat history for both conversations")

                def to_messages(chat_history):
                    # Flatten (user, assistant) pairs into role/content dicts.
                    messages = []
                    for turn in chat_history:
                        messages.append({"role": "USER", "content": turn[0]})
                        messages.append({"role": "ASSISTANT", "content": turn[1]})
                    return messages

                chat1_messages = to_messages(chat_history1)
                chat2_messages = to_messages(chat_history2)
                comparison_results = blender.compare_conversations([chat1_messages], [chat2_messages], return_logits=True)
                logit = comparison_results[0]
                # The sign of the logit picks the winner; its magnitude is shown as confidence.
                if logit > 0:
                    result = "Assistant's response in Conversation A is better than Conversation B"
                    prob = f"Confidence: {round(logit, 2)}"
                elif logit < 0:
                    result = "Assistant's response in Conversation B is better than Conversation A"
                    prob = f"Confidence: {round(abs(logit), 2)}"
                else:
                    result = "Assistant's response in Conversation A and Conversation B are equally good"
                    prob = "No confidence for tie"

                return [result, prob]

            # Compare runs compare_conv_fn on both chat histories.
            compare_button.click(
                fn=compare_conv_fn,
                inputs=[chatbot1, chatbot2],
                outputs=[compare_result, compare_result_prob],
            )

            # Hidden (visible=False) textboxes that receive the columns of each
            # MT_BENCH_HUMAN_JUDGE_EXAMPLES row; the conversations arrive as str()
            # of message lists and get_mt_bench_human_judge_examples parses them
            # into the two Chatbots.
            model_a_dummy_textbox = gr.Textbox(lines=1, label="Model A", placeholder="", show_label=False, visible=False)
            model_b_dummy_textbox = gr.Textbox(lines=1, label="Model B", placeholder="", show_label=False, visible=False)
            winner_dummy_textbox = gr.Textbox(lines=1, label="Better Model in conversation", placeholder="", show_label=False, visible=False)
            chatbot1_dummy_textbox = gr.Textbox(lines=1, label="Conversation A", placeholder="", show_label=False, visible=False)
            chatbot2_dummy_textbox = gr.Textbox(lines=1, label="Conversation B", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)")
            # inputs order must match the row layout
            # [model_a, model_b, conversation_a, conversation_b, winner].
            gr.Examples(
                MT_BENCH_HUMAN_JUDGE_EXAMPLES,
                fn=get_mt_bench_human_judge_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[model_a_dummy_textbox, model_b_dummy_textbox, chatbot1_dummy_textbox, chatbot2_dummy_textbox, winner_dummy_textbox],
                outputs=[chatbot1, chatbot2],
            )
    
    
    # with gr.Tab("LLM-Blender"):
    #     # llm-blender interface 
    #     with gr.Row():
    #         gr.Markdown(descriptions.LLM_BLENDER_OVERALL_DESC)
    #         gr.Image("https://github.com/yuchenlin/LLM-Blender/blob/main/docs/llm_blender.png?raw=true", height=300)
    #     gr.Markdown("## Input and Base LLMs")
    #     with gr.Row():
    #         with gr.Column():
    #             inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
    #             input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
    #         with gr.Column():
    #             saved_llm_outputs = gr.State(value={})
    #             with gr.Group():
    #                 selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
    #                     choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
    #                 selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
    #                     placeholder="Enter LLM-1 output here", show_label=True)
    #             with gr.Row():
    #                 base_llm_outputs_save_button = gr.Button('Save', variant='primary')
                    
    #                 base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
                    
    #                 base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
    #             base_llms_num = gr.Slider(
    #                     label='Number of base llms',
    #                     minimum=MIN_BASE_LLM_NUM,
    #                     maximum=MAX_BASE_LLM_NUM,
    #                     step=1,
    #                     value=MIN_BASE_LLM_NUM,
    #                 )
        
    #     blender_state = gr.State(value={})
    #     saved_rank_outputs = gr.State(value=[])
    #     saved_fuse_outputs = gr.State(value=[])
    #     gr.Markdown("## Blender Outputs")
    #     with gr.Group():
    #         rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
    #         fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
    #     with gr.Row():
    #         rank_button = gr.Button('Rank LLM Outputs', variant='primary')
    #         fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
    #         clear_button = gr.Button('Clear Blender Outputs', variant='primary')
    #     blender_config = gr.State(value={
    #         "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
    #         "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
    #         "top_k_for_fuser": 3,
    #         "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
    #         "temperature": 0.7,
    #         "top_p": 1.0,
    #     })
            
    #     with gr.Accordion(label='Advanced options', open=False):
    #         source_max_length = gr.Slider(
    #             label='Max length of Instruction + Input',
    #             minimum=1,
    #             maximum=SOURCE_MAX_LENGTH,
    #             step=1,
    #             value=DEFAULT_SOURCE_MAX_LENGTH,
    #         )
    #         candidate_max_length = gr.Slider(
    #             label='Max length of LLM-Output Candidate',
    #             minimum=1,
    #             maximum=CANDIDATE_MAX_LENGTH,
    #             step=1,
    #             value=DEFAULT_CANDIDATE_MAX_LENGTH,
    #         )
    #         top_k_for_fuser = gr.Slider(
    #             label='Top-k ranked candidates to fuse',
    #             minimum=1,
    #             maximum=3,
    #             step=1,
    #             value=3,
    #         )
    #         max_new_tokens = gr.Slider(
    #             label='Max new tokens fuser can generate',
    #             minimum=1,
    #             maximum=FUSER_MAX_NEW_TOKENS,
    #             step=1,
    #             value=DEFAULT_FUSER_MAX_NEW_TOKENS,
    #         )
    #         temperature = gr.Slider(
    #             label='Temperature of fuser generation',
    #             minimum=0.1,
    #             maximum=2.0,
    #             step=0.1,
    #             value=0.7,
    #         )
    #         top_p = gr.Slider(
    #             label='Top-p of fuser generation',
    #             minimum=0.05,
    #             maximum=1.0,
    #             step=0.05,
    #             value=1.0,
    #         )
        
    #     examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)     
    #     batch_examples = gr.Examples(
    #         examples=MIX_INSTRUCT_EXAMPLES,
    #         fn=get_preprocess_examples,
    #         cache_examples=True,
    #         examples_per_page=5,
    #         inputs=[inst_textbox, input_textbox],
    #         outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
    #     )
            
    #     base_llms_num.input(
    #         fn=update_base_llms_num,
    #         inputs=[base_llms_num, saved_llm_outputs],
    #         outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
    #     )
        
    #     examples_dummy_textbox.change(
    #         fn=update_base_llm_dropdown_along_examples,
    #         inputs=[inst_textbox, input_textbox],
    #         outputs=[saved_llm_outputs, rank_outputs, fuser_outputs, selected_base_llm_name_dropdown],
    #     ).then(
    #         fn=display_llm_output,
    #         inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
    #         outputs=selected_base_llm_output,
    #     )
        
    #     selected_base_llm_name_dropdown.change(
    #         fn=display_llm_output,
    #         inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
    #         outputs=selected_base_llm_output,
    #     )
        
    #     base_llm_outputs_save_button.click(
    #         fn=save_llm_output,
    #         inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
    #         outputs=saved_llm_outputs,
    #     )
    #     base_llm_outputs_clear_all_button.click(
    #         fn=lambda: [{}, ""],
    #         inputs=[],
    #         outputs=[saved_llm_outputs, selected_base_llm_output],
    #     )
    #     base_llm_outputs_clear_single_button.click(
    #         fn=lambda: "",
    #         inputs=[],
    #         outputs=selected_base_llm_output,
    #     )
            

    #     rank_button.click(
    #         fn=check_save_ranker_inputs,
    #         inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
    #         outputs=blender_state,
    #     ).success(
    #         fn=llms_rank,
    #         inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
    #         outputs=[saved_rank_outputs, rank_outputs],
    #     )
        
    #     fuse_button.click(
    #         fn=check_fuser_inputs,
    #         inputs=[blender_state, blender_config, saved_rank_outputs],
    #         outputs=[],
    #     ).success(
    #         fn=llms_fuse,
    #         inputs=[blender_state, blender_config, saved_rank_outputs],
    #         outputs=[saved_fuse_outputs, fuser_outputs],
    #     )
        
    #     clear_button.click(
    #         fn=lambda: ["", "", {}, []],
    #         inputs=[],
    #         outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
    #     )
        
    #     # update blender config
    #     source_max_length.change(
    #         fn=lambda x, y: y.update({"source_max_length": x}) or y,
    #         inputs=[source_max_length, blender_config],
    #         outputs=blender_config,
    #     )
    #     candidate_max_length.change(
    #         fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
    #         inputs=[candidate_max_length, blender_config],
    #         outputs=blender_config,
    #     )
    #     top_k_for_fuser.change(
    #         fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
    #         inputs=[top_k_for_fuser, blender_config],
    #         outputs=blender_config,
    #     )
    #     max_new_tokens.change(
    #         fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
    #         inputs=[max_new_tokens, blender_config],
    #         outputs=blender_config,
    #     )
    #     temperature.change(
    #         fn=lambda x, y: y.update({"temperature": x}) or y,
    #         inputs=[temperature, blender_config],
    #         outputs=blender_config,
    #     )
    #     top_p.change(
    #         fn=lambda x, y: y.update({"top_p": x}) or y,
    #         inputs=[top_p, blender_config],
    #         outputs=blender_config,
    #     )
        
    

    # Citation text, placed inside the Blocks but outside the "PairRM" tab.
    gr.Markdown(descriptions.CITATION)
# Enable request queuing (max_size=20 caps queued events; see gr.Blocks.queue
# docs) and start the server. There is no __main__ guard, so importing this
# module launches the app.
demo.queue(max_size=20).launch()