import random
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Constants for the model and configuration
MODEL_ID = "AstroMLab/AstroSage-8B"
# Token budget shared by the prompt and the generated reply.
WINDOW_SIZE = 2048
# Device the tokenized inputs are moved to before generation.
DEVICE = "cuda"

# Load model configuration, tokenizer, and model.
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
# 4-bit quantization is requested through a BitsAndBytesConfig. Passing
# `load_in_4bit` directly to `from_pretrained` is deprecated in transformers.
# The compute dtype matches `torch_dtype` so the dequantized matmuls run in
# bfloat16 rather than the config's float32 default.
# NOTE(review): the checkpoint is presumably a stock Llama-3.1 architecture;
# if so, `trust_remote_code=True` is unnecessary and could be dropped.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)

# Opening lines for the assistant; one is chosen at random whenever the app
# needs a greeting (initial chat state, or an empty history).
GREETING_MESSAGES = [
    "Greetings! I am AstroSage, your guide to the cosmos. What would you like to explore today?",
    "Welcome to our cosmic journey! I am AstroSage. How may I assist you in understanding the universe?",
    "AstroSage here. Ready to explore the mysteries of space and time. How may I be of assistance?",
    "The universe awaits! I'm AstroSage. What astronomical wonders shall we discuss?",
]


def format_message(role: str, content: str) -> str:
    """Render a single chat turn in Llama-3 special-token form.

    Args:
        role: Speaker name written between the header tokens
            (for example "system", "user" or "assistant").
        content: Text of the turn.

    Returns:
        The role wrapped in ``<|start_header_id|>``/``<|end_header_id|>``,
        a blank line, the content, and a closing ``<|eot_id|>`` token.
    """
    template = "<|start_header_id|>{0}<|end_header_id|>\n\n{1}<|eot_id|>"
    return template.format(role, content)


def _build_llama3_prompt(prompt: str, history: list) -> str:
    """Render prior turns plus the new user prompt as one raw Llama-3 prompt."""
    parts = ["<|begin_of_text|>"]
    parts.extend(format_message(msg["role"], msg["content"]) for msg in history)
    parts.append(format_message("user", prompt))
    # Leave an open assistant header so the model continues as the assistant.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    # Llama-3 turns are concatenated directly: the chat template puts no
    # newline between <|eot_id|> and the next <|start_header_id|>.
    return "".join(parts)


def _encode_prompt(prompt: str, history: list):
    """Tokenize the Llama-3 prompt for *prompt*/*history* and move it to DEVICE."""
    text = _build_llama3_prompt(prompt, history)
    # The text already begins with <|begin_of_text|>. Letting the tokenizer add
    # its own special tokens as well would risk a duplicated BOS token.
    return tokenizer([text], return_tensors="pt", add_special_tokens=False).to(DEVICE)


def generate_text(prompt: str, history: list, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Stream a model reply to *prompt*, using Llama-3 chat formatting.

    Args:
        prompt: The new user message. It is appended as the final user turn,
            so callers should not also include it at the end of *history*.
        history: Earlier turns as dicts with "role" and "content" keys.
            The list is not modified.
        max_new_tokens: Upper bound on generated tokens. It is further capped
            so that prompt plus reply fit in WINDOW_SIZE.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        gr.Error: If the prompt leaves no room for a reply even after every
            non-system history turn has been dropped.
    """
    # Work on a copy so trimming never mutates the caller's history.
    turns = list(history or [])
    inputs = _encode_prompt(prompt, turns)

    # Keep prompt + requested reply inside WINDOW_SIZE by discarding the
    # oldest non-system turns. The system turn holds the assistant persona,
    # so it is always kept.
    while inputs["input_ids"].shape[-1] + max_new_tokens > WINDOW_SIZE:
        oldest = next(
            (i for i, msg in enumerate(turns) if msg["role"] != "system"), None
        )
        if oldest is None:
            break
        del turns[oldest]
        inputs = _encode_prompt(prompt, turns)

    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
    if max_new_tokens <= 0:
        # Fail fast with a user-visible message instead of handing generate()
        # a non-positive budget in the worker thread, where the error would be
        # lost and the streamer would only time out.
        raise gr.Error(
            "Your message is too long for AstroSage's context window. "
            "Please shorten it and try again."
        )

    # Prepare text streamer for live updates.
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )

    # Generate in a background thread so this generator can stream tokens.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    # The streamer is exhausted once generation signals its end; reap the thread.
    thread.join()


def user(user_message, history):
    """Record the submitted text as a new user turn.

    Args:
        user_message: Text taken from the input box.
        history: Current chat messages, or None when the chat is empty.

    Returns:
        A pair ``("", new_history)``. The empty string clears the textbox, and
        ``new_history`` is a fresh list ending with the user's message. The
        incoming list is left untouched.
    """
    prior = [] if history is None else history
    new_entry = {"role": "user", "content": user_message}
    return "", [*prior, new_entry]

@spaces.GPU
def bot(history):
    """Stream the assistant's reply to the most recent user message.

    Args:
        history: Chat messages as dicts with "role" and "content" keys. After
            ``user`` runs, the last entry is the user's new message.

    Yields:
        A copy of *history* with an assistant message appended. Its content
        grows as the reply streams in. If there is nothing to answer, the
        history is yielded once, unchanged, or a greeting when it is empty.
    """
    if not history:
        # Nothing to answer yet. Greet, rather than prompting the model with
        # its own greeting text as if the user had said it.
        yield [{"role": "assistant", "content": random.choice(GREETING_MESSAGES)}]
        return

    *previous_turns, last = history
    if last["role"] != "user":
        # Only a pending user turn needs a reply.
        yield history
        return

    # generate_text appends the prompt as the final user turn itself, so pass
    # only the earlier turns. Otherwise the user message would appear twice.
    response_generator = generate_text(last["content"], previous_turns)

    # Build a new list so the caller's history is not mutated.
    updated = list(history)
    updated.append({"role": "assistant", "content": ""})

    for partial_response in response_generator:
        updated[-1]["content"] = partial_response
        yield updated


def initial_greeting():
    """Build the starting contents of the chat.

    Returns:
        A two-element list of message dicts. The first is the system prompt
        that defines the AstroSage persona. The second is an assistant
        greeting picked at random from GREETING_MESSAGES.
    """
    system_turn = {
        "role": "system",
        "content": "You are AstroSage, an intelligent AI assistant specializing in astronomy, astrophysics, and cosmology. Provide accurate, scientific information while making complex concepts accessible. You're enthusiastic about space exploration and maintain a sense of wonder about the cosmos. Start by introducing yourself.",
    }
    opening_line = random.choice(GREETING_MESSAGES)
    greeting_turn = {"role": "assistant", "content": opening_line}
    return [system_turn, greeting_turn]


# Custom CSS for a space theme
custom_css = """
#component-0 {
    background-color: #1a1a2e;
    border-radius: 15px;
    padding: 20px;
}
.dark {
    background-color: #0f0f1a;
}
.contain {
    max-width: 1200px !important;
}
"""

# Create the Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate")) as demo:
    gr.Markdown(
        """
        # 🌌 AstroSage: Your Cosmic AI Companion
        
        Welcome to AstroSage, an advanced AI assistant specializing in astronomy, astrophysics, and cosmology. 
        Powered by the AstroSage-Llama-3.1-8B model, I'm here to help you explore the wonders of the universe!
        
        ### What Can I Help You With?
        - 🪐 Explanations of astronomical phenomena
        - 🚀 Space exploration and missions
        - ⭐ Stars, galaxies, and cosmology
        - 🌍 Planetary science and exoplanets
        - 📊 Astrophysics concepts and theories
        - 🔭 Astronomical instruments and observations
        
        Just type your question below and let's embark on a cosmic journey together!
        """
    )

    chatbot = gr.Chatbot(
        label="Chat with AstroSage",
        bubble_full_width=False,
        show_label=True,
        height=450,
        type="messages"
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Type your message here",
            placeholder="Ask me anything about space and astronomy...",
            scale=9
        )
        clear = gr.Button("Clear Chat", scale=1)

    # Example questions for quick start
    gr.Examples(
        examples=[
            "What is a black hole and how does it form?",
            "Can you explain the life cycle of a star?",
            "What are exoplanets and how do we detect them?",
            "Tell me about the James Webb Space Telescope.",
            "What is dark matter and why is it important?"
        ],
        inputs=msg,
        label="Example Questions"
    )

    # Submitting first records the user turn (unqueued, so it shows at once),
    # then streams the bot's reply into the same chatbot.
    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot,
        chatbot,
        chatbot
    )

    # Clearing resets to the initial state rather than to None. Wiping the
    # history would also discard the system prompt, so later turns would be
    # generated without the AstroSage persona.
    clear.click(initial_greeting, None, chatbot, queue=False)

    # Initial greeting
    demo.load(initial_greeting, None, chatbot, queue=False)

# Launch the app
if __name__ == "__main__":
    demo.launch()