import os
import re
from dotenv import load_dotenv
load_dotenv()

import gradio as gr

from langchain.agents.openai_assistant import OpenAIAssistantRunnable
from langchain.schema import HumanMessage, AIMessage

# Both values are read from the process environment (which load_dotenv() may
# have populated from a .env file). Either may be None if the variable is unset.
api_key = os.environ.get('OPENAI_API_KEY')
extractor_agent = os.environ.get('ASSISTANT_ID_SOLUTION_SPECIFIER_B')

# LangChain runnable bound to the assistant above. No thread is pinned here:
# an invoke() call whose input carries no "thread_id" starts a new thread, and
# predict() reads `.thread_id` / `.return_values` from the result it gets back.
extractor_llm = OpenAIAssistantRunnable(
    as_agent=True,
    api_key=api_key,
    assistant_id=extractor_agent,
)

# Id of the assistant thread the chat is currently continuing; None means the
# next message opens a new thread. This is a single module-level value (there
# is no per-session storage), so it is shared by everything in this process.
THREAD_ID = None

# Inline source markers emitted by the assistant, e.g. "【12†source】",
# "【4:0†source】" or "【4:0†report v2.pdf】": a numeric index (optionally
# colon-separated, like "4:0"), a dagger, then any label up to the closing
# bracket. Compiled once at import time instead of on every call.
_CITATION_PATTERN = re.compile(r"【\d+(?::\d+)*†[^】]+】")


def remove_citation(text, replacement="📚"):
    """Replace assistant citation markers in *text* with *replacement*.

    Args:
        text: Assistant reply that may contain markers such as
            ``【12†source】`` or ``【4:0†file.pdf】``.
        replacement: String substituted for each marker. It is inserted
            literally (backslashes are not treated as regex group
            references). Defaults to ``"📚"``.

    Returns:
        *text* with every citation marker replaced. Other text, including
        lenticular brackets that are not citation markers, is left unchanged.
    """
    # A callable repl makes re.sub insert `replacement` verbatim.
    return _CITATION_PATTERN.sub(lambda _match: replacement, text)

def predict(message, history):
    """Send one chat turn to the assistant and return its cleaned reply.

    Gradio's ChatInterface calls this with the newest user ``message`` and the
    ``history`` it has displayed so far. An empty ``history`` is taken to mean
    a fresh conversation, so the remembered thread id is cleared first and the
    assistant opens a new thread; the id of that thread is then stored in the
    module-level ``THREAD_ID`` and reused for later turns.

    Returns the assistant's output text after ``remove_citation`` has replaced
    its citation markers.

    NOTE(review): ``THREAD_ID`` is one module-level variable, so every browser
    session served by this process appears to share (and reset) the same
    thread. Consider keying thread ids per Gradio session if several people
    use the app at once.
    """
    global THREAD_ID

    # Debug output of what Gradio passed in.
    print("current history:", history)

    # Empty history => the UI started over; drop the previous thread.
    if not history:
        THREAD_ID = None

    opening_new_thread = THREAD_ID is None
    request = {"content": message}
    if not opening_new_thread:
        # Continue the existing conversation on the assistant side.
        request["thread_id"] = THREAD_ID

    reply = extractor_llm.invoke(request)
    if opening_new_thread:
        # Remember the thread the assistant just created for the next turn.
        THREAD_ID = reply.thread_id

    return remove_citation(reply.return_values["output"])

# Wrap `predict` in Gradio's chat UI and start serving it; share=True also
# requests a public Gradio share link for the app.
chat = gr.ChatInterface(predict, title="Solution Specifier B")
chat.launch(share=True)