# %%writefile app.py
from setup_code import *  # This imports everything from setup_code.py
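
# app.py defines a small multi-agent Streamlit chatbot:
#   - Classify_Agent labels each query with one of the classes below,
#   - Query_Agent embeds the query and retrieves matches from Pinecone,
#   - Relevant_Documents_Agent checks that those matches are relevant,
#   - Answering_Agent generates the reply (and an optional DALL-E image),
#   - Head_Agent wires the agents together and runs the Streamlit loop.
# Helpers such as get_completion, get_top_k_text, is_Yes, and contains_sorry
# are assumed to be provided by setup_code.py via the wildcard import above.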

general_greeting_num = 0
general_question_num = 1
machine_learning_num = 2
python_code_num = 3
obnoxious_num = 4
default_num = 5

query_classes = {
    '[General greeting]': general_greeting_num,
    '[General question]': general_question_num,
    '[Question about Machine Learning]': machine_learning_num,
    '[Question about Python code]': python_code_num,
    '[Obnoxious statement]': obnoxious_num,
}

query_classes_text = ", ".join(query_classes.keys())

class Classify_Agent:
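    """Labels a user query as one of the classes in query_classes using the OpenAI client."""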
    def __init__(self, openai_client) -> None:
        # Store the OpenAI client used to classify queries
        self.openai_client = openai_client

    def classify_query(self, query):
        prompt = f"Please classify this query in angle brackets <{query}> as one of the following in square brackets only: {query_classes_text}."
        classification_response = get_completion(self.openai_client, prompt)

        # Map the returned label to its numeric class; anything unexpected falls back to default_num
        if classification_response is not None and classification_response in query_classes:
            return query_classes[classification_response]

        return default_num

class Relevant_Documents_Agent:
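    """Checks whether text retrieved from Pinecone is relevant to the user's query and conversation."""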
    def __init__(self, openai_client) -> None:
        # Store the OpenAI client used for relevance checks
        self.client = openai_client

    def get_relevance(self, conversation) -> str:
        # Placeholder: not used in the current flow
        pass

    def get_relevant_docs(self, conversation, docs) -> str:
        # Placeholder: retrieval is currently handled by Query_Agent.query_vector_store
        pass

    def is_relevant(self, matches_text, user_query_plus_conversation) -> bool:
      prompt = f"Please confirm that the text in angle brackets: <{matches_text}>, is relevant to the text in double square brackets: [[{user_query_plus_conversation}]]. Return Yes or No"
      response = get_completion(self.client, prompt)

      return is_Yes(response)

class Query_Agent:
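    """Embeds queries with OpenAI and retrieves the top-k matching chunks from a Pinecone index."""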
    def __init__(self, pinecone_index, pinecone_index_python, openai_client, embeddings) -> None:
        # Store the Pinecone indexes, OpenAI client, and embeddings used for retrieval
        self.pinecone_index = pinecone_index
        self.pinecone_index_python = pinecone_index_python
        self.openai_client = openai_client
        self.embeddings = embeddings

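    # Embed text with the text-embedding-ada-002 model; newlines are replaced with
    # spaces before embedding.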
    def get_openai_embedding(self, text, model="text-embedding-ada-002"):
        text = text.replace("\n", " ")
        return self.openai_client.embeddings.create(input=[text], model=model).data[0].embedding

    def query_vector_store(self, query, index=None, k=5) -> str:
        if index is None:
            index = self.pinecone_index

        query_embedding = self.get_openai_embedding(query)

        def get_namespace(index):
            # Use the first namespace reported by the index stats
            stats = index.describe_index_stats()
            return list(stats['namespaces'].keys())[0]

        ns = get_namespace(index)

        matches_text = get_top_k_text(index.query(
            namespace=ns,
            top_k=k,
            vector=query_embedding,
            include_values=True,
            include_metadata=True
            )
        )
        return matches_text

class Answering_Agent:
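    """Generates the final answer, plus an image caption for DALL-E, from the retrieved text."""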
    def __init__(self, openai_client) -> None:
        # Store the OpenAI client used to generate answers and image captions
        self.client = openai_client

    def generate_response(self, query, docs, conv_history, selected_mode):
        # Summarize the retrieved text into an answer for the selected education level
        prompt_for_gpt = f"Based on this text in angle brackets: <{docs}>, please summarize a response to this query: {query} in the context of this conversation: {conv_history}. Please use language appropriate for a {selected_mode}."
        return get_completion(self.client, prompt_for_gpt)

    def generate_image(self, text):
        caption_prompt = f"Based on this text, repeated here in double square brackets for your reference: [[{text}]], please generate a simple caption that I can use with dall-e to generate an instructional image."
        caption_text = get_completion(self.client, caption_prompt)
        #st.write(caption_text)
        image = Head_Agent.text_to_image(self.client, caption_text)
        return image

class Head_Agent:
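    """Orchestrator: owns the OpenAI and Pinecone clients, the sub-agents, and the Streamlit chat loop."""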
    def __init__(self, openai_key, pinecone_key) -> None:
        # Create the OpenAI and Pinecone clients and the two Pinecone indexes
        self.openai_key = openai_key
        self.pinecone_key = pinecone_key
        self.selected_mode = ""

        self.openai_client = OpenAI(api_key=self.openai_key)
        self.pc = Pinecone(api_key=self.pinecone_key)
        self.pinecone_index = self.pc.Index("index-600")
        self.pinecone_index_python = self.pc.Index("index-py-files")

        self.setup_sub_agents()

    def setup_sub_agents(self):
        # Set up the sub-agents that handle classification, retrieval, relevance checks, and answering
        self.classify_agent = Classify_Agent(self.openai_client)
        self.query_agent = Query_Agent(self.pinecone_index, self.pinecone_index_python, self.openai_client, None)  # Pass embeddings if needed
        self.answering_agent = Answering_Agent(self.openai_client)
        self.relevant_documents_agent = Relevant_Documents_Agent(self.openai_client)

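    # Retrieval-and-answer pipeline for ML / Python questions: build context from recent
    # history, query Pinecone, verify relevance, then generate (or decline) an answer.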
    def process_query_response(self, user_query, query_topic):
        # Retrieve the history related to the query_topic
        conversation = []
        index = self.pinecone_index
        if query_topic == "ml":
            conversation = Head_Agent.get_history_about('ml')
        elif query_topic == 'python':
            conversation = Head_Agent.get_history_about('python')
            index = self.pinecone_index_python

        # get matches from Query_Agent, which uses Pinecone
        user_query_plus_conversation = f"The current query is: {user_query}"
        if len(conversation) > 0:
            conversation_text = "\n".join(conversation)
            user_query_plus_conversation += f'The current conversation is: {conversation_text}'
       
        # st.write(user_query_plus_conversation)
        matches_text = self.query_agent.query_vector_store(user_query_plus_conversation, index)

        if self.relevant_documents_agent.is_relevant(matches_text, user_query_plus_conversation):
            # Possible enhancement: fall back to a general-knowledge answer when there is no relevant match
            response = self.answering_agent.generate_response(user_query, matches_text, conversation, self.selected_mode)
        else:
            response = "Sorry, I don't have relevant information to answer that query."

        return response

    @staticmethod
    def get_conversation():
        # Return recent user messages with no topic filter
        return Head_Agent.get_history_about()

    @staticmethod
    def get_history_about(topic=None):
        history = []

        for message in st.session_state.messages:
            role = message["role"]
            content = message["content"]

            if topic is None:
                if role == "user":
                    history.append(f"{content} ")
            else:
                if message["topic"] == topic:
                    history.append(f"{content} ")

        # st.write(f"user history in get_conversation is {history}")

        # Keep only the two most recent matching messages to bound the prompt size
        history = history[-2:]

        return history

    @staticmethod
    def text_to_image(openai_client, text):
        # Generate an image with DALL-E 3 and download it into a PIL Image
        response = openai_client.images.generate(
            model="dall-e-3",
            prompt=text,
            n=1,
            size="1024x1024"
        )
        image_url = response.data[0].url
        with urllib.request.urlopen(image_url) as image_response:
            img = Image.open(BytesIO(image_response.read()))

        return img

    def main_loop_1(self):
        # Run the main Streamlit loop for the chatbot
        st.title("Mini Project 2: Streamlit Chatbot")

        # Check for existing session state variables
        if "openai_model" not in st.session_state:
            # ... (initialize model)
            # st.session_state.openai_model = openai_client #'GPT-3.5-turbo'
            st.session_state.openai_model = 'gpt-3.5-turbo'

        if "messages" not in st.session_state:
            # ... (initialize messages)
            st.session_state.messages = []

        # Define the selection options
        modes = ['1st grade student', 'middle school student', 'high school student', 'college student', 'grad student']

        # Use st.selectbox to let the user select a mode
        self.selected_mode = st.selectbox("Select your education level:", modes)

        # Display existing chat messages
        for message in st.session_state.messages:
            if message["role"] == "assistant":
                with st.chat_message("assistant"):
                    st.write(message["content"])
                    if message['image'] is not None:
                        st.image(message['image'])
            else:
                with st.chat_message("user"):
                    st.write(message["content"])

        # Wait for user input
        if user_query := st.chat_input("What would you like to chat about?"):
            # The user message is appended to the session history after the response is generated (see below)

            # Display the user message
            with st.chat_message("user"):
                st.write(user_query)

            # Generate AI response
            with st.chat_message("assistant"):
                # Classify the query, then route it to the appropriate handler
                response = ""
                topic = None
                image = None
                hasImage = False

                # Combine the recent conversation with the new user query to check the user's intention
                conversation = self.get_conversation()
                user_query_plus_conversation = f"The current query is: {user_query}. The current conversation is: {conversation}"
                classify_query = self.classify_agent.classify_query(user_query_plus_conversation)

                if classify_query == general_greeting_num:
                    response = "How can I assist you today?"
                elif classify_query == general_question_num:
                    response = "Please ask a question about Machine Learning or Python Code."
                elif classify_query == machine_learning_num:
                    # process_query_response will: 1. get matches from Pinecone via Query_Agent,
                    # 2. verify the matches are relevant, 3. generate a response
                    response = self.process_query_response(user_query, 'ml')

                    # answering agent will generate an image
                    if not contains_sorry(response):
                        image = self.answering_agent.generate_image(response)
                        hasImage = True
                        topic = "ml"

                elif classify_query == python_code_num:
                    response = self.process_query_response(user_query, 'python')
                    # answering agent will generate an image
                    if not contains_sorry(response):
                        image = self.answering_agent.generate_image(response)
                        hasImage = True
                        topic = "python"

                elif classify_query == obnoxious_num:
                    response = "Please dont be obnoxious."
                else:
                    response = "I'm not sure how to respond to that."

                # Display the response (and image, if one was generated)
                st.write(response)
                if hasImage:
                    st.image(image)

                # Append the user query and assistant response to the session history.
                # The user message is appended here, after generation, so the current
                # query is not included in the history built by get_conversation() above.
                st.session_state.messages.append({"role": "user", "content": user_query, "topic": topic, "image": None})
                st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})

if __name__ == "__main__":
    head_agent = Head_Agent(OPENAI_KEY, pc_apikey)
    head_agent.main_loop_1()