import os
import threading
from datetime import datetime
from typing import Any, Dict, Iterator, List

import gradio as gr
import requests

# Hugging Face API token, read from the environment (Spaces secrets);
# evaluates to None when the variable is not set.
HF_API_KEY = os.environ.get("HF_API_KEY")

# Inference API URL for every chat participant, keyed by display name.
_INFERENCE_API_ROOT = "https://api-inference.huggingface.co/models"
MODEL_ENDPOINTS = {
    display_name: f"{_INFERENCE_API_ROOT}/{repo_id}"
    for display_name, repo_id in (
        ("Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-72B-Instruct"),
        ("Llama3.3-70B-Instruct", "meta-llama/Llama-3.3-70B-Instruct"),
        ("Qwen2.5-Coder-32B-Instruct", "Qwen/Qwen2.5-Coder-32B-Instruct"),
    )
}

def query_model(
    model_name: str,
    messages: List[Dict[str, str]],
    timeout: float = 120.0,
) -> str:
    """Query a single model with the chat history.

    The messages are flattened into ``role: content`` lines, wrapped in the
    model's own chat template and sent to the model's Inference API endpoint.

    Args:
        model_name: Key of ``MODEL_ENDPOINTS`` selecting the model to call.
        messages: Chat turns; each dict must provide ``role`` and ``content``.
        timeout: Seconds to wait on the HTTP request before giving up, so a
            stalled endpoint cannot block the chat indefinitely.

    Returns:
        The cleaned generated text. When the API token is missing, the
        request fails or the response cannot be parsed, the string
        ``"<model_name> error: <details>"`` is returned instead.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_ENDPOINTS``.
    """
    endpoint = MODEL_ENDPOINTS[model_name]

    # Without a token the header would carry the literal text "Bearer None";
    # report the misconfiguration directly instead of calling the API.
    if not HF_API_KEY:
        return f"{model_name} error: HF_API_KEY is not set"

    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "application/json",
    }

    # Build full conversation history for context
    conversation = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages)

    # Model-specific prompt formatting with full history
    model_prompts = {
        "Qwen2.5-72B-Instruct": (
            f"<|im_start|>system\nCollaborate with other experts. Previous discussion:\n{conversation}<|im_end|>\n"
            "<|im_start|>assistant\nMy analysis:"
        ),
        "Llama3.3-70B-Instruct": (
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"Build upon this discussion:\n{conversation}<|eot_id|>\n"
            "<|start_header_id|>assistant<|end_header_id|>\nMy contribution:"
        ),
        "Qwen2.5-Coder-32B-Instruct": (
            f"<|im_start|>system\nTechnical discussion context:\n{conversation}<|im_end|>\n"
            "<|im_start|>assistant\nTechnical perspective:"
        ),
    }

    # Model-specific stop sequences
    stop_sequences = {
        "Qwen2.5-72B-Instruct": ["<|im_end|>", "<|endoftext|>"],
        "Llama3.3-70B-Instruct": ["<|eot_id|>", "\nuser:"],
        "Qwen2.5-Coder-32B-Instruct": ["<|im_end|>", "<|endoftext|>"],
    }

    payload = {
        "inputs": model_prompts[model_name],
        "parameters": {
            # The text-generation task names these parameters
            # "max_new_tokens" and "stop".
            "max_new_tokens": 2048,
            "temperature": 0.7,
            "stop": stop_sequences[model_name],
            "return_full_text": False,
        },
    }

    try:
        response = requests.post(
            endpoint, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        result = response.json()[0]['generated_text']
        # Clean up response formatting
        result = result.split('<|')[0]  # Drop everything from the first special token on
        result = result.replace('**', '').replace('##', '')  # Remove markdown markers
        return result.strip()
    except Exception as e:
        # Deliberate best-effort boundary: any failure (network, HTTP status,
        # unexpected response shape) is shown in the chat as an error message
        # rather than raised.
        return f"{model_name} error: {str(e)}"

def respond(message: str, history: List[List[str]], session_id: str) -> Iterator[str]:
    """Stream the three models' replies to one user message, in sequence.

    The models are queried one after another; each model receives the user
    message together with the replies of the models that answered before it.

    Args:
        message: The text the user just submitted.
        history: Chat history supplied by the chat UI. It is not used here;
            only the current message is sent to the models.
        session_id: Identifier handed to ``session_manager.load_session``.

    Yields:
        The transcript accumulated so far, one ``**<model>**:`` section per
        model that has answered. ``gr.ChatInterface`` shows each yielded value
        in place of the previous one, so yielding only the newest reply would
        leave just the last model's answer visible.
    """
    # NOTE(review): session_manager is not defined in the visible part of this
    # file - confirm it is provided elsewhere.
    session = session_manager.load_session(session_id)

    # Store user message in session
    session["history"].append({
        "timestamp": datetime.now().isoformat(),
        "type": "user",
        "content": message
    })

    # Order in which the models join the discussion
    model_order = (
        "Qwen2.5-Coder-32B-Instruct",
        "Qwen2.5-72B-Instruct",
        "Llama3.3-70B-Instruct",
    )

    messages = [{"role": "user", "content": message}]
    transcript: List[str] = []

    for model_name in model_order:
        reply = query_model(model_name, messages)
        transcript.append(f"**{model_name}**:\n{reply}")
        yield "\n\n".join(transcript)

        # Add this reply to the context exactly once, so later models do not
        # receive an earlier answer twice.
        messages.append({
            "role": "assistant",
            "content": f"Previous response: {reply}"
        })

# Create the Gradio interface with session management
# NOTE(review): session_manager is not defined or imported in the visible part
# of this file - confirm it is provided elsewhere before this block runs.
with gr.Blocks(title="Multi-LLM Collaboration Chat") as demo:
    # State component holding the session id. The create_session method itself
    # (not its result) is passed as the initial value - presumably so Gradio
    # invokes it to produce the id; confirm against the gr.State documentation.
    session_id = gr.State(session_manager.create_session)
    
    # Header row: page heading next to the button that starts a new session
    with gr.Row():
        gr.Markdown("## Multi-LLM Collaboration Chat")
        new_session_btn = gr.Button("🆕 New Session", variant="secondary")
    
    # Subtitle row naming the participating models
    with gr.Row():
        gr.Markdown("A group chat with Qwen2.5-72B, Llama3.3-70B, and Qwen2.5-Coder-32B")
    
    # Chat widget driven by respond(); session_id is supplied as an additional
    # input, matching respond's third parameter.
    chat_interface = gr.ChatInterface(
        respond,
        examples=["How can I optimize Python code?", "Explain quantum computing basics"],
        additional_inputs=[session_id]
    )
    
    def create_new_session():
        """Create a new session and return its id plus None for the chatbot."""
        new_id = session_manager.create_session()
        # The two values map onto outputs=[session_id, chat_interface.chatbot]
        # below; None presumably resets the displayed conversation - verify.
        return new_id, None
    
    # Wire the button: store the new id in session_id and update the chatbot
    new_session_btn.click(
        fn=create_new_session,
        outputs=[session_id, chat_interface.chatbot],
        show_progress=False
    )

if __name__ == "__main__":
    # Launch the enclosing Blocks app rather than the embedded ChatInterface,
    # so the heading, the New Session button and its click handler are part
    # of the served page.
    demo.launch(share=True)