import argparse
import logging
from threading import Thread

import time
import torch
import gradio as gr
import spaces
from concept_guidance.chat_template import DEFAULT_CHAT_TEMPLATE
from concept_guidance.patching import patch_model, load_weights
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer, Conversation

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device used for generation. It is hard-coded to CUDA rather than auto-detected.
# To run on a machine without a GPU, replace it with e.g.
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): presumably hard-coded for the on-demand GPU (`spaces.GPU`) hosting
# environment, where CUDA may not be reported as available at import time — confirm
# before switching to auto-detection.
device = torch.device("cuda")

# Per-model settings; comment entries in/out to choose which models are served.
# Every entry here is loaded eagerly at startup, together with its concept vectors.
#
# Keys:
#   identifier              model id/path passed to `from_pretrained`
#   dtype                   torch dtype the weights are loaded in (half precision on CUDA, fp32 otherwise)
#   load_in_8bit            load with 8-bit quantization; such models are never moved between devices
#   guidance_interval       [min, max] range of the guidance-scale slider
#   default_guidance_scale  initial guidance-scale slider value
#   min_guidance_layer      default first guided layer (1-based, inclusive)
#   max_guidance_layer      default last guided layer (1-based, inclusive)
#   default_concept         concept selected when the model is chosen
#   concepts                concepts offered in the concept dropdown
#
# Approximate memory requirements:
#   RAM:  ~16GB per model (+ ~4GB overhead)
#   VRAM: ~16GB
#   int8: ~8GB VRAM per model, low RAM requirements
MODEL_CONFIGS = {
    "Llama-2-7b-chat-hf": {
        "identifier": "meta-llama/Llama-2-7b-chat-hf",
        "dtype": torch.float16 if device.type == "cuda" else torch.float32,
        "load_in_8bit": False,
        "guidance_interval": [-16.0, 16.0],
        "default_guidance_scale": 8.0,
        "min_guidance_layer": 16,
        "max_guidance_layer": 32,
        "default_concept": "humor",
        "concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
    },
    # Disabled example entry. Enabling it also requires trained concept vectors at
    # trained_concepts/Mistral-7B-Instruct-v0.1/<concept>.safetensors, which are
    # loaded at startup.
    # "Mistral-7B-Instruct-v0.1": {
    #     "identifier": "mistralai/Mistral-7B-Instruct-v0.1",
    #     "dtype": torch.bfloat16 if device.type == "cuda" else torch.float32,
    #     "load_in_8bit": False,
    #     "guidance_interval": [-128.0, 128.0],
    #     "default_guidance_scale": 48.0,
    #     "min_guidance_layer": 8,
    #     "max_guidance_layer": 32,
    #     "default_concept": "humor",
    #     "concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
    # },
}

def load_concept_vectors(model, concepts):
    """Load the trained concept vector of every concept in ``concepts`` for ``model``.

    Each vector is read with ``load_weights`` from
    ``trained_concepts/<model>/<concept>.safetensors``.

    Returns:
        dict mapping concept name -> loaded weights, in ``concepts`` order.
    """
    vectors = {}
    for concept_name in concepts:
        weights_path = f"trained_concepts/{model}/{concept_name}.safetensors"
        vectors[concept_name] = load_weights(weights_path)
    return vectors

def load_model(model_name):
    """Instantiate the causal LM and tokenizer configured in ``MODEL_CONFIGS[model_name]``.

    If the tokenizer ships without a chat template, ``DEFAULT_CHAT_TEMPLATE`` is
    installed so chat-formatted prompts can still be built.

    NOTE(review): recent transformers releases deprecate the ``load_in_8bit`` kwarg
    in favor of ``quantization_config`` — verify against the pinned version.

    Returns:
        ``(model, tokenizer)``
    """
    cfg = MODEL_CONFIGS[model_name]
    repo_id = cfg["identifier"]
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=cfg["dtype"],
        load_in_8bit=cfg["load_in_8bit"],
    )
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    if tokenizer.chat_template is None:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return model, tokenizer

# Union of the concepts shipped with this app; kept as a module-level reference.
CONCEPTS = ["humor", "creativity", "quality", "truthfulness", "compliance"]
# Load exactly the concepts each model's config advertises, so the concept
# dropdown (fed from config["concepts"]) and the vector lookup during generation
# always agree, even when models ship different concept sets.
CONCEPT_VECTORS = {
    model_name: load_concept_vectors(model_name, config["concepts"])
    for model_name, config in MODEL_CONFIGS.items()
}
# Every configured model (and its tokenizer) is loaded eagerly at import time.
MODELS = {model_name: load_model(model_name) for model_name in MODEL_CONFIGS}


def history_to_conversation(history):
    """Convert Gradio chat history into a transformers ``Conversation``.

    ``history`` is a sequence of ``(user_prompt, assistant_completion)`` pairs.
    A ``None`` completion (a turn still waiting for its reply) is skipped, so
    only the user message of that turn is added.

    NOTE(review): ``Conversation`` was deprecated in later transformers
    releases — check the pinned version before upgrading.
    """
    conversation = Conversation()
    for user_text, assistant_text in history:
        conversation.add_message({"role": "user", "content": user_text})
        if assistant_text is None:
            continue
        conversation.add_message({"role": "assistant", "content": assistant_text})
    return conversation



def set_defaults(model_name):
    """Build Gradio updates that reset the settings controls to ``model_name``'s defaults.

    Returns a 5-tuple matching the outputs wired to the model dropdown: the model
    name itself, then updates for the concept dropdown, the guidance-scale slider,
    and the first/last guidance-layer sliders.
    """
    config = MODEL_CONFIGS[model_name]
    concepts = config["concepts"]
    # Select the configured default concept (the same one the UI starts with),
    # falling back to the first concept for configs that don't declare one.
    default_concept = config.get("default_concept", concepts[0])
    interval = config["guidance_interval"]
    return (
        model_name,
        gr.update(choices=concepts, value=default_concept),
        gr.update(minimum=interval[0], maximum=interval[1], value=config["default_guidance_scale"]),
        gr.update(value=config["min_guidance_layer"]),
        gr.update(value=config["max_guidance_layer"]),
    )

def add_user_prompt(user_message, history):
    """Append a new user turn with no reply yet and return the history.

    ``history`` is mutated in place; ``None`` starts a fresh history.
    """
    turns = [] if history is None else history
    turns.append([user_message, None])
    return turns

def _offload_inactive_models(active_model_name):
    """Move every model except ``active_model_name`` to the CPU, then release cached GPU memory.

    Models configured with ``load_in_8bit`` are never relocated.
    NOTE(review): presumably because quantized models do not support ``.to()`` — confirm.
    """
    for name, (model, _) in MODELS.items():
        if name != active_model_name and not MODEL_CONFIGS[name]["load_in_8bit"]:
            model.to("cpu")
    torch.cuda.empty_cache()


@spaces.GPU
@torch.no_grad()
def generate_completion(
    history,
    model_name,
    concept,
    guidance_scale=4.0,
    min_guidance_layer=16,
    max_guidance_layer=32,
    temperature=0.0,
    repetition_penalty=1.2,
    length_penalty=1.2,
):
    """Stream a concept-guided assistant reply to the last user turn of ``history``.

    The selected model is moved onto ``device`` (all other models are offloaded to
    the CPU). It is then patched with the ``concept`` vector on the requested layers
    and run in a background thread, and its text is streamed back incrementally.

    Args:
        history: Chat history as ``[user, assistant]`` pairs; the last pair's
            assistant slot is filled in place.
        model_name: Key into ``MODEL_CONFIGS`` / ``MODELS``.
        concept: Name of the concept vector to guide with.
        guidance_scale: Guidance strength passed to ``patch_model``.
        min_guidance_layer: First guided layer (1-based, inclusive).
        max_guidance_layer: Last guided layer (1-based, inclusive).
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        repetition_penalty: Forwarded to generation.
        length_penalty: Forwarded to generation.

    Yields:
        ``history`` after each streamed chunk of text.

    Raises:
        Any exception raised by the generation thread, re-raised here after the
        stream is closed, instead of leaving the stream blocked forever.
    """
    start_time = time.time()
    logger.info(f" --- Starting completion ({model_name}, {concept=}, {guidance_scale=}, {min_guidance_layer=}, {temperature=})") 
    logger.info(" User: " + repr(history[-1][0]))

    # Keep only the active model on the GPU.
    _offload_inactive_models(model_name)
    config = MODEL_CONFIGS[model_name]
    model, tokenizer = MODELS[model_name]
    if not config["load_in_8bit"]:
        model.to(device, non_blocking=True)

    concept_vector = CONCEPT_VECTORS[model_name][concept]
    # Layers are 1-based and inclusive in the UI; convert to 0-based indices.
    guidance_layers = list(range(int(min_guidance_layer) - 1, int(max_guidance_layer)))
    patch_model(model, concept_vector, guidance_scale=guidance_scale, guidance_layers=guidance_layers)
    pipe = pipeline("conversational", model=model, tokenizer=tokenizer, device=(device if not config["load_in_8bit"] else None))

    conversation = history_to_conversation(history)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        max_new_tokens=1024,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        streamer=streamer,
        temperature=temperature,
        do_sample=(temperature > 0),
    )

    errors = []

    def _generate():
        # Runs in the worker thread. The streamer has no timeout, and its stop
        # signal is only sent when generation finishes normally. On failure, close
        # the stream explicitly so the consumer loop below terminates, and hand
        # the exception back to the caller.
        try:
            pipe(conversation, **generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate, daemon=True)
    thread.start()

    history[-1][1] = ""
    for chunk in streamer:
        history[-1][1] += chunk
        yield history

    # Make sure any failure in the worker has been recorded before checking it.
    thread.join()
    if errors:
        raise errors[0]

    logger.info(" Assistant: " + repr(history[-1][1]))

    time_taken = time.time() - start_time
    logger.info(f" --- Completed (took {time_taken:.1f}s)")
    return history


class ConceptGuidanceUI:
    """Chat UI for concept-guided generation.

    Instantiating this class inside a ``gr.Blocks`` context builds the settings
    column and the chat column, then wires every event handler (model switch,
    submit, retry, undo, clear, stop). Button components are kept on the instance.
    """

    def __init__(self):
        model_names = list(MODEL_CONFIGS.keys())
        default_model = model_names[0]
        default_config = MODEL_CONFIGS[default_model]
        default_concepts = default_config["concepts"]
        default_concept = default_config["default_concept"]

        # Carries the user's message from one chained event step to the next.
        saved_input = gr.State("")

        with gr.Row(elem_id="concept-guidance-container"):
            # Left column: model and generation settings.
            with gr.Column(scale=1, min_width=256):
                model_dropdown = gr.Dropdown(model_names, value=default_model, label="Model")
                concept_dropdown = gr.Dropdown(default_concepts, value=default_concept, label="Concept")
                guidance_scale = gr.Slider(*default_config["guidance_interval"], value=default_config["default_guidance_scale"], label="Guidance Scale")
                # Layer numbers are 1-based and inclusive. The initial values come from the
                # default model's config, the same keys set_defaults applies on model change.
                min_guidance_layer = gr.Slider(1.0, 32.0, value=default_config["min_guidance_layer"], step=1.0, label="First Guidance Layer")
                max_guidance_layer = gr.Slider(1.0, 32.0, value=default_config["max_guidance_layer"], step=1.0, label="Last Guidance Layer")
                temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Temperature")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.01, label="Repetition Penalty")
                length_penalty = gr.Slider(0.0, 2.0, value=1.2, step=0.01, label="Length Penalty")

            # Right column: transcript, history controls and the message box.
            with gr.Column(scale=3, min_width=512):
                chatbot = gr.Chatbot(scale=1, height=200)

                with gr.Row():
                    self.retry_btn = gr.Button("🔄 Retry", size="sm")
                    self.undo_btn = gr.Button("↩️ Undo", size="sm")
                    self.clear_btn = gr.Button("🗑️ Clear", size="sm")

                with gr.Group():
                    with gr.Row():
                        prompt_field = gr.Textbox(placeholder="Type a message...", show_label=False, label="Message", scale=7, container=False)
                        self.submit_btn = gr.Button("Submit", variant="primary", scale=1, min_width=150)
                        self.stop_btn = gr.Button("Stop", variant="secondary", scale=1, min_width=150, visible=False)

        # Order must match generate_completion's parameters after ``history``.
        generation_args = [
            model_dropdown,
            concept_dropdown,
            guidance_scale,
            min_guidance_layer,
            max_guidance_layer,
            temperature,
            repetition_penalty,
            length_penalty,
        ]

        model_dropdown.change(set_defaults, [model_dropdown], [model_dropdown, concept_dropdown, guidance_scale, min_guidance_layer, max_guidance_layer], queue=False)

        # Submit: stash and clear the textbox, append the user turn, then stream the reply.
        # concurrency_limit=1 serializes generations, because generate_completion moves
        # the shared models between devices and calls patch_model on them.
        submit_triggers = [prompt_field.submit, self.submit_btn.click]
        submit_event = gr.on(
            submit_triggers, self.clear_and_save_input, [prompt_field], [prompt_field, saved_input], queue=False
        ).then(
            add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
        ).then(
            generate_completion,
            [chatbot] + generation_args,
            [chatbot],
            concurrency_limit=1,
        )
        self.setup_stop_events(submit_triggers, submit_event)

        # Retry: drop the last turn, re-append its prompt and regenerate the reply.
        retry_triggers = [self.retry_btn.click]
        retry_event = gr.on(
            retry_triggers, self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
        ).then(
            add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
        ).then(
            generate_completion,
            [chatbot] + generation_args,
            [chatbot],
            concurrency_limit=1,
        )
        self.setup_stop_events(retry_triggers, retry_event)

        # Undo: drop the last turn and put its prompt back into the textbox.
        self.undo_btn.click(
            self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
        ).then(
            lambda x: x, [saved_input], [prompt_field]
        )
        # Clear: empty the transcript and the stashed message.
        self.clear_btn.click(lambda: [None, None], None, [chatbot, saved_input], queue=False)

    def clear_and_save_input(self, message):
        """Return ``("", message)``: clear the textbox and stash its content."""
        return "", message

    def delete_prev_message(self, history):
        """Remove the last turn from ``history``.

        Returns:
            ``(history, message)``, where ``message`` is the removed user prompt
            (``""`` if it was ``None``). An empty or ``None`` history yields
            ``([], "")`` instead of raising.
        """
        if not history:
            return [], ""
        message, _ = history.pop()
        return history, message or ""

    def setup_stop_events(self, event_triggers, event_to_cancel):
        """Swap Submit for Stop while ``event_to_cancel`` runs, and make Stop cancel it.

        Each trigger hides Submit and shows Stop immediately. After
        ``event_to_cancel`` completes, Submit is shown again. Clicking Stop
        cancels ``event_to_cancel``.
        """
        for event_trigger in event_triggers:
            event_trigger(
                lambda: (gr.Button(visible=False), gr.Button(visible=True)),
                None,
                [self.submit_btn, self.stop_btn],
                show_api=False,
                queue=False,
            )
        event_to_cancel.then(
            lambda: (gr.Button(visible=True), gr.Button(visible=False)),
            None,
            [self.submit_btn, self.stop_btn],
            show_api=False,
            queue=False,
        )

        self.stop_btn.click(
            None,
            None,
            None,
            cancels=event_to_cancel,
            show_api=False,
        )

# Page stylesheet: let the main row grow to fill the available height.
css = "\n".join(
    [
        "#concept-guidance-container {",
        "    flex-grow: 1;",
        "}",
    ]
)

with gr.Blocks(title="Concept Guidance", fill_height=True, css=css) as demo:
    # The UI creates its components inside this Blocks context on construction.
    ConceptGuidanceUI()

# Enable request queuing; the streaming generation events and the Stop button's
# cancellation run through the queue.
demo.queue()

if __name__ == "__main__":
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--share", action="store_true")
    cli_args = cli_parser.parse_args()
    demo.launch(share=cli_args.share)