import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import spaces
import os
import json
from huggingface_hub import login

# Hugging Face authentication (Space secrets are exposed as plain
# environment variables, so the secret named HF_TOKEN is read directly)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        print(f"Error logging in to Hugging Face: {e}")
else:
    print("HF_TOKEN is not set; gated models will not be accessible.")

# File to store model links
MODEL_FILE = "model_links.txt"

def load_model_links():
    """Load model repo IDs (one per line) from MODEL_FILE, creating it with a default entry if missing"""
    if not os.path.exists(MODEL_FILE):
        # Create the file with a default example model
        with open(MODEL_FILE, "w") as f:
            f.write("meta-llama/Llama-2-7b-chat-hf\n")

    with open(MODEL_FILE, "r") as f:
        return [line.strip() for line in f if line.strip()]
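
# New models can be made available by appending one repo id per line to
# MODEL_FILE. A hypothetical helper (not wired into the UI) might look like:
#
#     def add_model_link(repo_id: str) -> None:
#         with open(MODEL_FILE, "a") as f:
#             f.write(repo_id + "\n")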

class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        # Don't touch CUDA here: on ZeroGPU Spaces, CUDA may only be
        # initialized inside functions decorated with @spaces.GPU.
    
    def load_model(self, model_name):
        """Load a model, freeing the previously loaded model's memory first"""
        if self.current_model is not None:
            del self.current_model
            del self.current_tokenizer
            self.current_model = None
            self.current_tokenizer = None
            self.current_model_name = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        try:
            self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto"  # Let accelerate decide device placement
            )
            self.current_model_name = model_name
            return f"Successfully loaded model: {model_name}"
        except Exception as e:
            return f"Error loading model: {str(e)}"

    def tokenize(self, prompt):
        """Tokenize a prompt and move the tensors to the model's device"""
        inputs = self.current_tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the model may sit on GPU while the
        # tokenizer returns CPU tensors, so move them explicitly.
        return inputs.to(self.current_model.device)
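
# If GPU memory is tight, 4-bit quantized loading is an option. A minimal
# sketch (not wired into the UI, assuming the optional bitsandbytes package
# is installed):
#
#     from transformers import BitsAndBytesConfig
#
#     quant_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     )
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         quantization_config=quant_config,
#         device_map="auto",
#     )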


# Initialize model manager
model_manager = ModelManager()

# Default system message for JSON output
default_system_message = """You are a helpful AI assistant. You must ALWAYS return your response in valid JSON format.
Each response should be formatted as follows:
{
    "response": {
        "main_answer": "Your primary response here",
        "additional_details": "Any additional information or context",
        "confidence": 0.0 to 1.0,
        "tags": ["relevant", "tags", "here"]
    },
    "metadata": {
        "response_type": "type of response",
        "source": "basis of response if applicable"
    }
}
Ensure EVERY response strictly follows this JSON structure."""

@spaces.GPU  # Requests GPU allocation (e.g. ZeroGPU) for the duration of this call
def generate_response(model_name, system_instruction, user_input):
    """Generate response with GPU support and JSON formatting"""
    if model_manager.current_model is None:
        return json.dumps({"error": "No model loaded. Please load a model first."}, indent=2)

    if model_manager.current_model_name != model_name:
        return json.dumps({"error": "The selected model is not the loaded one. Please load it using the 'Load Selected Model' button."}, indent=2)

    prompt = f"""### Instruction:
{system_instruction}
Remember to ALWAYS format your response as valid JSON.
### Input:
{user_input}
### Response:
{{"""

    try:
        inputs = model_manager.tokenize(prompt)
        
        generation_config = GenerationConfig(
            do_sample=False,  # greedy decoding; temperature is ignored when do_sample=False
            max_new_tokens=512,
            repetition_penalty=1.1,
            use_cache=True,
            pad_token_id=model_manager.current_tokenizer.eos_token_id,
            eos_token_id=model_manager.current_tokenizer.eos_token_id
        )

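        # Inference only: disable autograd bookkeeping to save memory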
        with torch.no_grad():
            outputs = model_manager.current_model.generate(
                **inputs,
                generation_config=generation_config
            )
            
            decoded_output = model_manager.current_tokenizer.batch_decode(
                outputs, 
                skip_special_tokens=True
            )[0]
            
            # Everything after "### Response:" is the model's completion,
            # which begins at the "{" appended to the prompt
            assistant_response = decoded_output.split("### Response:")[-1].strip()

            try:
                # Trim anything after the final closing brace before parsing
                last_brace = assistant_response.rindex('}')
                assistant_response = assistant_response[:last_brace + 1]
                json_response = json.loads(assistant_response)
                return json.dumps(json_response, indent=2)
            except (json.JSONDecodeError, ValueError):
                return json.dumps({
                    "error": "Failed to generate valid JSON",
                    "raw_response": assistant_response
                }, indent=2)
                
    except Exception as e:
        return json.dumps({
            "error": f"Error generating response: {str(e)}",
            "details": "An unexpected error occurred during generation"
        }, indent=2)
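
# Minimal smoke test outside the Gradio UI (hypothetical repo id shown;
# uncomment to run in an environment with enough GPU memory):
#
#     print(model_manager.load_model("meta-llama/Llama-2-7b-chat-hf"))
#     print(generate_response("meta-llama/Llama-2-7b-chat-hf",
#                             default_system_message, "What is Gradio?"))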


# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("# Chat Interface with Model Selection (JSON Output)")

    with gr.Row():
        # Left column for inputs
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=load_model_links(),
                label="Select Model",
                info="Choose a model from the list"
            )
            load_button = gr.Button("Load Selected Model")
            model_status = gr.Textbox(label="Model Status")
            
            system_instruction = gr.Textbox(
                value=default_system_message,
                placeholder="Enter system instruction here...",
                label="System Instruction",
                lines=3
            )
            user_input = gr.Textbox(
                placeholder="Type your message here...",
                label="Your Message",
                lines=3
            )
            submit_btn = gr.Button("Submit")

        # Right column for bot response
        with gr.Column():
            response_display = gr.Textbox(
                label="Bot Response (JSON)", 
                interactive=False, 
                placeholder="Response will appear here in JSON format.",
                lines=10
            )

    # Event handlers
    load_button.click(
        fn=model_manager.load_model,
        inputs=[model_dropdown],
        outputs=[model_status]
    )
    
    submit_btn.click(
        fn=generate_response,
        inputs=[model_dropdown, system_instruction, user_input],
        outputs=[response_display]
    )

# Launch the app
demo.launch()