# app_hybrid_llm.py
import os
import re
import numpy as np
import faiss
import gradio as gr
import openai
from openai import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer


# Credentials for the OpenAI-compatible chat endpoint used below.
DARTMOUTH_CHAT_API_KEY = os.getenv('DARTMOUTH_CHAT_API_KEY')
if not DARTMOUTH_CHAT_API_KEY:
    # Treat an empty value (e.g. `export DARTMOUTH_CHAT_API_KEY=`) the same as
    # unset, so the app fails here at startup instead of on the first API request.
    raise ValueError("DARTMOUTH_CHAT_API_KEY not set.")

# Model name and endpoint keep their original defaults but can be overridden
# from the environment without editing this file.
MODEL = os.getenv("DARTMOUTH_CHAT_MODEL", "openai.gpt-4o-2024-08-06")

client = OpenAI(
    base_url=os.getenv("DARTMOUTH_CHAT_BASE_URL", "https://chat.dartmouth.edu/api"),
    api_key=DARTMOUTH_CHAT_API_KEY,
)

# --- Load and Prepare Data ---
with open("gen_agents.txt", "r", encoding="utf-8") as f:
    full_text = f.read()

# Split on blank lines into chunks. chunk_size/chunk_overlap are measured with the
# splitter's length function (len, i.e. characters, unless configured otherwise).
text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=512, chunk_overlap=20)
docs = text_splitter.create_documents([full_text])
passages = [doc.page_content for doc in docs]
if not passages:
    # Without this check an empty source file surfaces later as an opaque
    # IndexError on passage_embeddings.shape[1]; fail early and say why.
    raise ValueError("gen_agents.txt produced no passages; is the file empty?")

# Embed every passage once at startup and add them to an exhaustive-search L2
# index. Row i of the index corresponds to passages[i]. FAISS expects float32.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
passage_embeddings = embedder.encode(passages, convert_to_tensor=False, show_progress_bar=True)
passage_embeddings = np.array(passage_embeddings).astype("float32")
d = passage_embeddings.shape[1]  # embedding dimensionality
index = faiss.IndexFlatL2(d)
index.add(passage_embeddings)

# --- Provided Functions ---
def retrieve_passages(query, embedder, index, passages, top_k=3):
    """
    Return the passages whose embeddings are nearest to ``query``.

    Args:
        query: Question text to embed.
        embedder: Object with ``encode(list_of_str, convert_to_tensor=False)``
            (a SentenceTransformer in this file).
        index: FAISS-style index whose ``search(vectors, k)`` returns
            ``(distances, indices)``.
        passages: Passage texts, in the same order they were added to ``index``.
        top_k: Maximum number of passages to return.

    Returns:
        A list of at most ``top_k`` passage strings, in the order reported by
        ``index.search``.
    """
    query_embedding = embedder.encode([query], convert_to_tensor=False)
    query_embedding = np.array(query_embedding).astype('float32')
    _distances, indices = index.search(query_embedding, top_k)
    # FAISS pads missing results with label -1 when the index holds fewer than
    # top_k vectors; passages[-1] would silently repeat the last passage.
    return [passages[i] for i in indices[0] if i >= 0]

def process_llm_output_with_references(text, passages):
    """
    Replace tokens like <<PASSAGE_0>> in the LLM output with HTML block quotes.

    Token numbers are zero-based, matching the ``<<PASSAGE_i>>`` labels that
    ``generate_answer_with_references`` puts in the prompt: ``<<PASSAGE_i>>``
    expands to ``passages[i]``. Tokens whose number is out of range are left
    unchanged.

    Args:
        text: Raw LLM output.
        passages: The passages shown to the model, in prompt order.

    Returns:
        ``text`` with each valid token replaced by a styled ``<blockquote>``
        containing the HTML-escaped passage.
    """
    import html  # local import keeps this block self-contained

    def replacement(match):
        num = int(match.group(1))
        if 0 <= num < len(passages):
            # Zero-based lookup to match the prompt labels. Escape the passage so
            # characters like '<' and '&' in the source text render literally
            # instead of being parsed as markup.
            passage_text = html.escape(passages[num])
            return (f"<blockquote style='background: #ffffff; color: #000000; padding: 10px; "
                    f"border-left: 5px solid #ccc; margin: 10px 0; font-size: 14px;'>{passage_text}</blockquote>")
        return match.group(0)
    return re.sub(r"<<PASSAGE_(\d+)>>", replacement, text)

def generate_answer_with_references(query, retrieved_text):
    """
    Ask the chat model (``MODEL`` via ``client``) to answer ``query`` from the given passages.

    Each passage is listed in the prompt under a zero-based ``<<PASSAGE_i>>``
    label, and the model is asked to cite passages with those same tokens so
    they can be expanded into block quotes afterwards.

    Args:
        query: The user's question.
        retrieved_text: Passage strings to use as context, in citation order.

    Returns:
        The model's reply with surrounding whitespace stripped, or an empty
        string if the response carried no text content.
    """
    context_str = "\n".join(f"<<PASSAGE_{i}>>: \"{passage}\"" for i, passage in enumerate(retrieved_text))
    messages = [
        {"role": "system", "content": "You are a knowledgeable technical assistant."},
        {"role": "user", "content": (
            f"Using the following textbook passages as reference:\n{context_str}\n\n"
            "In your answer, include passage block quotes as references. "
            "Refer to the passages using tokens such as <<PASSAGE_0>>, <<PASSAGE_1>>, etc. "
            "They should appear after complete thoughts on a new line.\n\n"
            f"Answer the question: {query}"
        )}
    ]
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
    )
    # message.content is typed Optional[str] in the OpenAI Python SDK; guard
    # against None so we don't crash with AttributeError on .strip().
    content = response.choices[0].message.content
    return (content or "").strip()

# --- Gradio App Function ---
def get_hybrid_output(query):
    """
    Answer ``query`` against the loaded article and return HTML for the output box.

    Retrieves the top 3 passages, asks the LLM for an answer that cites them
    with ``<<PASSAGE_n>>`` tokens, then expands those tokens into block quotes.

    Args:
        query: Text from the query textbox.

    Returns:
        An HTML string. Returns an empty string for an empty or whitespace-only
        query, without doing retrieval or calling the API.
    """
    if not query or not query.strip():
        return ""
    retrieved = retrieve_passages(query, embedder, index, passages, top_k=3)
    answer_raw = generate_answer_with_references(query, retrieved)
    answer_html = process_llm_output_with_references(answer_raw, retrieved)
    # pre-wrap preserves the line breaks in the model's answer (citation tokens
    # are requested on their own lines).
    return f"<div style='white-space: pre-wrap;'>{answer_html}</div>"

def clear_output():
    """Produce the value that blanks the HTML output component: an empty string."""
    return str()

# --- Custom CSS ---
# Dark-theme stylesheet passed to gr.Blocks(css=...) below. "#container" targets
# the column created with elem_id="container", and ".output-box" targets the HTML
# output component created with elem_classes="output-box".
custom_css = """
body {
    background-color: #343541 !important;
    color: #ECECEC !important;
    margin: 0;
    padding: 0;
    font-family: 'Inter', sans-serif;
}
#container {
    max-width: 900px;
    margin: 0 auto;
    padding: 20px;
}
label {
    color: #ECECEC;
    font-weight: 600;
}
textarea, input {
    background-color: #40414F;
    color: #ECECEC;
    border: 1px solid #565869;
}
button {
    background-color: #565869;
    color: #ECECEC;
    border: none;
    font-weight: 600;
    transition: background-color 0.2s ease;
}
button:hover {
    background-color: #6e7283;
}
.output-box {
    border: 1px solid #565869;
    border-radius: 4px;
    padding: 10px;
    margin-top: 8px;
    background-color: #40414F;
}
"""

# --- Build Gradio Interface ---
with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("## Anonymous Chatbot\n### Loaded Article: Generative Agents - Interactive Simulacra of Human Behavior (Park et al. 2023)\n [https://arxiv.org/pdf/2304.03442](https://arxiv.org/pdf/2304.03442)")
        gr.Markdown("Enter any questions about the article above in the prompt!")
        query_input = gr.Textbox(label="Query", placeholder="Enter your query here...", lines=1)
        with gr.Column():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")
        output_box = gr.HTML(label="Output", elem_classes="output-box")

        submit_button.click(fn=get_hybrid_output, inputs=query_input, outputs=output_box)
        # Pressing Enter in the single-line query box runs the same pipeline as Submit.
        query_input.submit(fn=get_hybrid_output, inputs=query_input, outputs=output_box)
        clear_button.click(fn=clear_output, inputs=[], outputs=output_box)

# Start the web server only when run as a script, so importing this module
# does not block on launch().
if __name__ == "__main__":
    demo.launch()