from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
from src.settings import settings


class RAGPipeline:
    def __init__(
        self,
        document_store,
        template: str,
        top_k: int,
    ) -> None:
        self.text_embedder: SentenceTransformersTextEmbedder  # type: ignore
        self.retriever: QdrantEmbeddingRetriever  # type: ignore
        self.prompt_builder: PromptBuilder  # type: ignore
        self.llm_provider: OpenAIGenerator  # type: ignore
        self.pipeline: Pipeline | None = None
        self.document_store = document_store
        self.template = template
        self.top_k = top_k

        self.get_text_embedder()
        self.get_retriever()
        self.get_prompt_builder()
        self.get_llm_provider()

    def run(self, query: str, filter_selections: dict[str, list] | None = None) -> dict:
        if not self.pipeline:
            self.build_pipeline()
        if self.pipeline:
            filters = RAGPipeline.build_filter(filter_selections=filter_selections)
            result = self.pipeline.run(
                data={
                    "text_embedder": {"text": query},
                    "retriever": {"filters": filters},
                    "prompt_builder": {"query": query},
                },
                include_outputs_from=["retriever", "llm"],
            )
        return result

    def get_text_embedder(self) -> None:
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=settings.qdrant_database.model
        )
        self.text_embedder.warm_up()

    def get_retriever(self) -> None:
        self.retriever = QdrantEmbeddingRetriever(
            document_store=self.document_store, top_k=self.top_k
        )

    def get_prompt_builder(self) -> None:
        self.prompt_builder = PromptBuilder(template=self.template)

    def get_llm_provider(self) -> None:
        self.llm_provider = OpenAIGenerator(
            model=settings.llm_provider.model,
            api_key=Secret.from_env_var("LLM_PROVIDER__API_KEY"),
            max_retries=3,
            generation_kwargs={"max_tokens": 5000, "temperature": 0.2},
        )

    @staticmethod
    def build_filter(filter_selections: dict[str, list] | None = None) -> dict:
        filters: dict[str, str | list[dict]] = {"operator": "AND", "conditions": []}
        if filter_selections:
            for meta_data_name, selections in filter_selections.items():
                filters["conditions"].append(  # type: ignore
                    {
                        "field": "meta." + meta_data_name,
                        "operator": "in",
                        "value": selections,
                    }
                )
        else:
            filters = {}
        return filters

    def build_pipeline(self):
        self.pipeline = Pipeline()
        self.pipeline.add_component("text_embedder", self.text_embedder)
        self.pipeline.add_component("retriever", self.retriever)
        self.pipeline.add_component("prompt_builder", self.prompt_builder)
        self.pipeline.add_component("llm", self.llm_provider)

        self.pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        self.pipeline.connect("retriever", "prompt_builder.documents")
        self.pipeline.connect("prompt_builder", "llm")


if __name__ == "__main__":
    # NOTE(review): `DocumentStore` is not imported anywhere in this file, so
    # this entry point raises NameError as written — confirm which module it
    # should come from (likely a project-local wrapper) and import it.
    document_store = DocumentStore(index="inc_data")

    # Read the prompt template; encoding pinned so behavior doesn't depend on
    # the platform's default locale.
    with open("src/rag/prompt_templates/inc_template.txt", "r", encoding="utf-8") as file:
        template = file.read()

    pipeline = RAGPipeline(
        document_store=document_store.document_store, template=template, top_k=5
    )
    # Restrict retrieval to documents authored by these parties.
    filter_selections = {
        "author": ["Malaysia", "Australia"],
    }
    result = pipeline.run(
        "What is Malaysia's position on plastic waste?",
        filter_selections=filter_selections,
    )