import os
import streamlit as st
from streamlit_chat import message
from langchain_openai import OpenAIEmbeddings
from pinecone import Pinecone
import time
from langchain_pinecone.vectorstores import Pinecone as PineconeVectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain.memory import ConversationBufferMemory
from langchain_core.runnables import RunnableLambda
from operator import itemgetter

# Streamlit App Configuration
st.set_page_config(page_title="Docu-Help")

# Dropdown for namespace selection: the chosen value is the Pinecone namespace
# that generate_response() searches, so it must not be overwritten later.
namespace_name = st.sidebar.selectbox("Select Website:", ('crawlee', ''), key='namespace_name')

# Read API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINE_API_KEY = os.getenv("PINE_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
LANGCHAIN_TRACING_V2 = 'true'
LANGCHAIN_ENDPOINT = "https://api.smith.langchain.com"
LANGCHAIN_PROJECT = "docu-help"

# LangChain reads its LangSmith tracing settings from the process environment,
# not from Python globals, so the constants above have no effect unless they
# are exported. Only do so when a LangSmith key is configured, and never
# override values the deployment environment already provides.
if LANGCHAIN_API_KEY:
    os.environ.setdefault("LANGCHAIN_TRACING_V2", LANGCHAIN_TRACING_V2)
    os.environ.setdefault("LANGCHAIN_ENDPOINT", LANGCHAIN_ENDPOINT)
    os.environ.setdefault("LANGCHAIN_PROJECT", LANGCHAIN_PROJECT)

# Sidebar for model selection and API-key input
st.sidebar.title("Sidebar")
model_name = st.sidebar.radio("Choose a model:", ("gpt-3.5-turbo-1106", "gpt-4-0125-preview", "Claude-Sonnet", "mixtral-groq"))
# type="password" masks the keys so they are not displayed in clear text.
openai_api_key2 = st.sidebar.text_input("Enter OpenAI Key: ", type="password")
groq_api_key = st.sidebar.text_input("Groq API Key: ", type="password")
anthropic_api_key = st.sidebar.text_input("Claude API Key: ", type="password")
pinecone_index_name = os.getenv("pinecone_index_name")

# Initialize session state variables if they don't exist
if 'generated' not in st.session_state:
    # Assistant answers, index-aligned with 'past'.
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    # User questions, in submission order.
    st.session_state['past'] = []

if 'messages' not in st.session_state:
    st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]

if 'total_cost' not in st.session_state:
    st.session_state['total_cost'] = 0.0
        
def refresh_text():
    """Re-render the whole conversation into ``response_container``.

    Question ``i`` is ``st.session_state['past'][i]`` and its answer is
    ``st.session_state['generated'][i]``. While a new question is being
    answered, ``past`` holds one more entry than ``generated``. That trailing
    question is drawn without an answer bubble. The streamed answer is
    rendered separately by ``generate_response``.
    """
    past = st.session_state['past']
    generated = st.session_state['generated']
    with response_container:
        for i, user_text in enumerate(past):
            st.chat_message("user").write(user_text)
            # An explicit bounds check replaces the former bare ``except:``.
            # A bare except also caught BaseException subclasses, including
            # KeyboardInterrupt and Streamlit's script-control exceptions.
            if i < len(generated):
                st.chat_message("assistant").write(generated[i])

# Answer a user question with retrieval-augmented generation: Pinecone context + the model selected in the sidebar
def generate_response(prompt):
    """Answer ``prompt`` with retrieval-augmented generation and stream it.

    Retrieves context for ``prompt`` from the configured Pinecone index and
    namespace. Replays the earlier turns (``past``/``generated`` in session
    state) as chat history. Streams the selected model's answer into
    ``response_container`` and appends the ``metadata['source']`` values of
    the retrieved documents.

    Side effects: appends the user turn and the assistant turn to
    ``st.session_state['messages']``. Appends the displayed text (``"Answer: "``
    + answer + ``Sources:`` section) to ``st.session_state['generated']``.

    Args:
        prompt: The user's question.

    Returns:
        The formatted response string that was displayed.
    """
    st.session_state['messages'].append({"role": "user", "content": prompt})

    retriever = _build_retriever()
    chat_model = _build_chat_model()
    memory = _build_memory()

    # Prints the memory that the model will be using
    print(f"Memory: {memory.load_memory_variables({})}")

    template = """You are an expert software developer who specializes in APIs. Answer the user's question based only on the following context:
                {context}

                Chat History:
                {chat_history}

                Question: {question}
                """
    prompt_template = ChatPromptTemplate.from_template(template)

    rag_chain = (
        RunnablePassthrough.assign(context=(lambda x: x["context"]), chat_history=lambda x: get_buffer_string(x["chat_history"]))
        | prompt_template
        | chat_model
        | StrOutputParser()
    )

    # Runs retrieval, question passthrough and history loading in parallel,
    # then adds the model's answer. The retrieved documents stay available
    # for the sources list.
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough(), "chat_history": RunnableLambda(memory.load_memory_variables) | itemgetter("history")}
    ).assign(answer=rag_chain)

    def make_stream():
        """Yield the display prefix, the answer tokens, then the sources section.

        Every yielded piece is also appended to
        ``st.session_state['generated'][-1]``, so the finished text survives
        later reruns.
        """
        sources = []
        st.session_state['generated'].append("Answer: ")
        yield st.session_state['generated'][-1]

        for chunk in rag_chain_with_source.stream(prompt):
            if 'answer' in chunk:
                st.session_state['generated'][-1] += chunk['answer']
                yield chunk['answer']
            elif 'context' in chunk:
                # dict.fromkeys de-duplicates while keeping retrieval order.
                # Documents that share a source would otherwise be listed
                # repeatedly.
                sources = list(dict.fromkeys(doc.metadata['source'] for doc in chunk['context']))

        sources_txt = "\n\nSources:\n" + "\n".join(sources)
        st.session_state['generated'][-1] += sources_txt
        yield sources_txt

    # Sending the message as a stream using the generator above
    print("Running the response streamer...")
    with response_container:
        st.chat_message("assistant").write_stream(make_stream())

    formatted_response = st.session_state['generated'][-1]
    st.session_state['messages'].append({"role": "assistant", "content": formatted_response})

    return formatted_response


def _build_retriever():
    """Return a retriever over index ``pinecone_index_name``, namespace ``namespace_name``."""
    embed = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY)

    pc = Pinecone(api_key=PINE_API_KEY)
    index = pc.Index(pinecone_index_name)
    # One round-trip to the index before building the chain, so connection
    # and auth problems surface here. The old fixed one-second sleep only
    # added latency to every question.
    index.describe_index_stats()

    vectorstore = PineconeVectorStore(index, embed, "text", namespace=namespace_name)
    return vectorstore.as_retriever()


def _build_chat_model():
    """Instantiate the chat model selected in the sidebar (temperature 0)."""
    if model_name == "Claude-Sonnet":
        return ChatAnthropic(temperature=0, model="claude-3-sonnet-20240229", anthropic_api_key=anthropic_api_key)
    if model_name == "mixtral-groq":
        return ChatGroq(temperature=0, groq_api_key=groq_api_key, model_name="mixtral-8x7b-32768")
    # A blank sidebar field falls back to the OPENAI_API_KEY environment value,
    # which is the same key used for the embeddings.
    return ChatOpenAI(temperature=0, model=model_name, openai_api_key=openai_api_key2 or OPENAI_API_KEY)


def _build_memory():
    """Return a ConversationBufferMemory pre-loaded with the completed Q/A turns."""
    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )
    # zip stops at the shorter list. The question currently being answered
    # (the extra trailing entry in 'past') is therefore not replayed.
    for question, answer in zip(st.session_state['past'], st.session_state['generated']):
        # "Answer: " is removed so the model does not learn to prepend it itself.
        memory.save_context({"question": question}, {"answer": answer.replace("Answer: ", "")})
    return memory

# Containers for the chat transcript and, below it, the text box area
response_container = st.container()
container = st.container()

# chat_input is used instead of a form because it stays pinned to the bottom of the page.
prompt = st.chat_input("Ask a question...")
if prompt:
    # Record the question first, so the transcript below shows it immediately
    # above the answer that generate_response streams.
    st.session_state['past'].append(prompt)

# Redraw the transcript on every script run. Streamlit drops elements that a
# run does not re-emit, so without this the conversation disappears whenever
# another widget (e.g. the model selector) triggers a rerun.
refresh_text()

if prompt:
    response = generate_response(prompt)