from modules.chat.helpers import get_prompt
from modules.chat.chat_model_loader import ChatModelLoader
from modules.vectorstore.store_manager import VectorStoreManager
from modules.retriever.retriever import Retriever
from modules.chat.langchain.langchain_rag import (
    Langchain_RAG_V2,
    QuestionGenerator,
)


class LLMTutor:
    """Assemble the chat model, vector store, prompts and RAG chain from a config.

    Components are built from ``self.config``; :meth:`update_llm` rebuilds only
    those whose configuration entries changed.
    """

    def __init__(self, config, user, logger=None):
        """
        Initialize the LLMTutor class.

        Args:
            config (dict): Configuration dictionary.
            user (str): User identifier.
            logger (Logger, optional): Logger instance. Defaults to None.
        """
        self.config = config
        self.user = user
        self.logger = logger
        self.llm = self.load_llm()
        self.vector_db = VectorStoreManager(config, logger=self.logger).load_database()
        self.qa_prompt = get_prompt(config, "qa")
        self.rephrase_prompt = get_prompt(config, "rephrase")

        # TODO: Removed this functionality for now, don't know if we need it
        # if self.config["vectorstore"]["embedd_files"]:
        #     self.vector_db.create_database()
        #     self.vector_db.save_database()

    def update_llm(self, old_config, new_config):
        """
        Update the LLM, vector store and QA prompt based on a new configuration.

        ``new_config`` becomes ``self.config`` before anything is rebuilt, so
        every reload sees the new settings. Only affected components are
        rebuilt:

        - ``llm_params.llm_loader``  -> ``self.llm``
        - ``vectorstore.db_option``  -> ``self.vector_db``
        - ``llm_params.llm_style``   -> ``self.qa_prompt``

        The QA chain is not rebuilt here; call :meth:`qa_bot` again to get a
        chain that uses the updated components.

        Args:
            old_config (dict): Configuration the current components were built from.
            new_config (dict): New configuration dictionary.
        """
        changes = self.get_config_changes(old_config, new_config)

        # Adopt the new configuration first: load_llm(), VectorStoreManager and
        # get_prompt all read self.config, so reloading before this assignment
        # would rebuild components from the old settings whenever the caller
        # passes a new_config object distinct from self.config.
        self.config = new_config

        if self._config_path_changed(changes, "llm_params.llm_loader"):
            self.llm = self.load_llm()

        if self._config_path_changed(changes, "vectorstore.db_option"):
            self.vector_db = VectorStoreManager(
                self.config, logger=self.logger
            ).load_database()

            # TODO: Removed this functionality for now, don't know if we need it
            # if self.config["vectorstore"]["embedd_files"]:
            #     self.vector_db.create_database()
            #     self.vector_db.save_database()

        if self._config_path_changed(changes, "llm_params.llm_style"):
            # NOTE(review): the original comment tied llm_style to an "ELI5"
            # mode; confirm against get_prompt. Only the QA prompt is rebuilt.
            self.qa_prompt = get_prompt(self.config, "qa")

    @staticmethod
    def _config_path_changed(changes, path):
        """
        Return True if the dotted config ``path`` is affected by ``changes``.

        Besides an exact match, this also counts:

        - a change recorded at an ancestor key, because ``get_config_changes``
          reports a whole section when it was added, removed or replaced by a
          non-dict;
        - a change recorded below ``path``, for when the value at ``path`` is
          itself a dict.

        Args:
            changes (dict): Mapping returned by ``get_config_changes``.
            path (str): Dotted key path, e.g. ``"llm_params.llm_loader"``.

        Returns:
            bool: Whether any recorded change touches ``path``.
        """
        return any(
            key == path or path.startswith(key + ".") or key.startswith(path + ".")
            for key in changes
        )

    def get_config_changes(self, old_config, new_config):
        """
        Get the changes between the old and new configuration.

        Nested dicts present in both configs are compared recursively. Other
        differing values, including a dict replaced by a non-dict, are
        recorded at their dotted key path.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.

        Returns:
            dict: Maps dotted key paths (e.g. ``"llm_params.llm_loader"``) to
            ``(old_value, new_value)`` tuples. A key missing on one side is
            reported with ``None`` for that side.
        """
        changes = {}

        def compare_dicts(old, new, parent_key=""):
            for key in new:
                full_key = f"{parent_key}.{key}" if parent_key else key
                if isinstance(new[key], dict) and isinstance(old.get(key), dict):
                    compare_dicts(old[key], new[key], full_key)
                elif old.get(key) != new[key]:
                    changes[full_key] = (old.get(key), new[key])
            # Include keys that are in old but not in new
            for key in old:
                if key not in new:
                    full_key = f"{parent_key}.{key}" if parent_key else key
                    changes[full_key] = (old[key], None)

        compare_dicts(old_config, new_config)
        return changes

    def retrieval_qa_chain(
        self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
    ):
        """
        Create a Retrieval QA Chain.

        Only ``config["llm_params"]["llm_arch"] == "langchain"`` is supported.
        As a side effect this sets ``self.qa_chain`` and
        ``self.question_generator``.

        Args:
            llm (LLM): The language model instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
            db (VectorStore): The vector store instance.
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (optional): Passed through to ``Langchain_RAG_V2``.
                Defaults to None.

        Returns:
            Chain: The retrieval QA chain instance.

        Raises:
            ValueError: If ``llm_params.llm_arch`` is not ``"langchain"``.
        """
        retriever = Retriever(self.config)._return_retriever(db)

        if self.config["llm_params"]["llm_arch"] == "langchain":
            self.qa_chain = Langchain_RAG_V2(
                llm=llm,
                memory=memory,
                retriever=retriever,
                qa_prompt=qa_prompt,
                rephrase_prompt=rephrase_prompt,
                config=self.config,
                callbacks=callbacks,
            )

            self.question_generator = QuestionGenerator()
        else:
            raise ValueError(
                f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
            )
        return self.qa_chain

    def load_llm(self):
        """
        Load the language model described by ``self.config``.

        Returns:
            LLM: The loaded language model instance.
        """
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm

    def qa_bot(self, memory=None, callbacks=None):
        """
        Create a QA bot chain from the instance's current LLM, prompts and vector store.

        Args:
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (optional): Passed through to ``retrieval_qa_chain``.
                Defaults to None.

        Returns:
            Chain: The QA bot chain instance.

        Raises:
            ValueError: If the vector store contains no documents, or if the
                configured ``llm_arch`` is unsupported.
        """
        # sanity check to see if there are any documents in the database
        if len(self.vector_db) == 0:
            raise ValueError(
                "No documents in the database. Populate the database first."
            )

        qa = self.retrieval_qa_chain(
            self.llm,
            self.qa_prompt,
            self.rephrase_prompt,
            self.vector_db,
            memory,
            callbacks=callbacks,
        )

        return qa