import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
import os

# Hugging Face token login.
# Only log in when a token is actually configured: calling login() without a
# token falls back to an interactive prompt, which a headless deployment
# cannot answer.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Registry of selectable models, keyed by the identifier offered in the
# "Model" dropdown. Fields of each entry:
#   name              - human-readable label (not read by the code visible here)
#   sizes             - size label -> checkpoint id passed to load_model()
#   emoji             - decorative emoji (not read by the code visible here)
#   experimental      - experimental marker (not read by the code visible here)
#   is_vision         - True shows the image upload box and mentions the image in the prompt
#   system_prompt_env - name of the env var holding this model's system prompt
MODELS = {
    "atlas-flash-1215": dict(
        name="🦁 Atlas-Flash 1215",
        sizes={"1.5B": "Spestly/Atlas-Flash-1.5B-Preview"},
        emoji="🦁",
        experimental=True,
        is_vision=False,
        system_prompt_env="ATLAS_FLASH_1215",
    ),
    "atlas-pro-0403": dict(
        name="🏆 Atlas-Pro 0403",
        sizes={"1.5B": "Spestly/Atlas-Pro-1.5B-Preview"},
        emoji="🏆",
        experimental=True,
        is_vision=False,
        system_prompt_env="ATLAS_PRO_0403",
    ),
}

# Checkpoint that is loaded at startup and preselected in the UI dropdowns.
default_model_key, default_size = "atlas-pro-0403", "1.5B"
# Resolve the (key, size) pair to the checkpoint id handed to load_model().
default_model = MODELS[default_model_key]["sizes"][default_size]

def load_model(model_name):
    """Load the tokenizer and causal language model for ``model_name``.

    The model weights are loaded as float32 with ``low_cpu_mem_usage=True``
    and the model is switched to eval mode before being returned.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_kwargs = dict(torch_dtype=torch.float32, low_cpu_mem_usage=True)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


# Load the default checkpoint once at import time; generate_response rebinds
# these module globals when it needs a different checkpoint.
tokenizer, model = load_model(model_name=default_model)

# Generate response function
def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    """Generate a reply to ``message`` and append the exchange to the chat transcript.

    Args:
        message: The user's text for this turn.
        image: Value of the image upload component. It is not passed to the
            model; for vision models only a note that an image was provided is
            added to the prompt.
        history: Current transcript from the ``gr.Chatbot`` component, or None.
            Assumes the Chatbot's default pair format (``[user, bot]`` items).
            Earlier turns are NOT fed back into the prompt; each request is
            answered independently.
        model_key: Key into ``MODELS``.
        model_size: Key into ``MODELS[model_key]["sizes"]``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens. It is cast to int
            because slider values may arrive as floats.

    Returns:
        The updated transcript: ``history`` followed by ``(message, response)``.
    """
    global tokenizer, model
    config = MODELS[model_key]
    selected_model = config["sizes"][model_size]

    # Reload only when the requested checkpoint differs from the one currently in
    # memory. Comparing against ``default_model`` (as before) reloaded on every call
    # for non-default models and never switched back to the default once another
    # model had been loaded.
    if getattr(model, "name_or_path", None) != selected_model:
        tokenizer, model = load_model(selected_model)

    system_prompt = os.getenv(
        config["system_prompt_env"],
        "You are an advanced AI system. Help the user as best as you can.",
    )

    if config["is_vision"]:
        image_info = "An image has been provided as input."
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n{image_info}\n\n### Response:"
    else:
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"

    inputs = tokenizer(instruction, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    # Many causal-LM tokenizers define no pad token; fall back to EOS so
    # generate() has an explicit pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # generate() returns prompt + continuation; decode only the new tokens rather
    # than string-splitting on the "### Response:" marker, which breaks whenever the
    # model itself emits that marker.
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    # gr.Chatbot renders the whole transcript, so return history plus this turn
    # instead of a bare string.
    return list(history or []) + [(message, response)]

def create_interface():
    """Build the Gradio Blocks app for chatting with the models in ``MODELS``.

    Components are laid out top to bottom:
    1. an intro text;
    2. the model and size pickers;
    3. an image upload that is only visible for vision models;
    4. the message box;
    5. the sampling sliders;
    6. the chat transcript and a submit button.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` instance.
    """
    initial_sizes = list(MODELS[default_model_key]["sizes"].keys())

    with gr.Blocks(title="🌟 Atlas-Pro/Flash/Vision Interface", theme="soft") as iface:
        gr.Markdown("Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Flash (Coming Soon!). Upload images for vision models!")

        model_dropdown = gr.Dropdown(label="Model", choices=list(MODELS.keys()), value=default_model_key)
        size_dropdown = gr.Dropdown(label="Model Size", choices=initial_sizes, value=default_size)
        image_box = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
        message_box = gr.Textbox(label="Message", placeholder="Type your message here...")
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
        top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
        chatbot = gr.Chatbot(label="Chatbot")
        submit = gr.Button("Submit")

        def on_model_change(selected_key):
            # Refresh the size choices and show/hide the image upload for the new model.
            info = MODELS[selected_key]
            sizes = list(info["sizes"])
            size_update = gr.Dropdown(choices=sizes, value=sizes[0])
            image_update = gr.Image(visible=info["is_vision"])
            return [size_update, image_update]

        model_dropdown.change(
            fn=on_model_change,
            inputs=model_dropdown,
            outputs=[size_dropdown, image_box],
        )

        # Order must match generate_response's positional parameters.
        generation_inputs = [
            message_box,
            image_box,
            chatbot,
            model_dropdown,
            size_dropdown,
            temperature,
            top_p,
            max_new_tokens,
        ]
        submit.click(fn=generate_response, inputs=generation_inputs, outputs=chatbot)

    return iface

# Start the web server only when run as a script, so the module can be imported
# (e.g. by tooling or tests) without launching the UI.
if __name__ == "__main__":
    create_interface().launch()