import os
import logging
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever, QdrantSparseEmbeddingRetriever
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.joiners.document_joiner import DocumentJoiner
from haystack.components.preprocessors.document_cleaner import DocumentCleaner
# from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.writers import DocumentWriter
from haystack.components.generators.openai import OpenAIGenerator
from haystack import Document, Pipeline
from haystack.utils import Secret
from haystack import tracing
from haystack.tracing.logging_tracer import LoggingTracer

# Load environment variables
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
# logging.getLogger("haystack").setLevel(logging.DEBUG)

# tracing.tracer.is_content_tracing_enabled = True # to enable tracing/logging content (inputs/outputs)
# tracing.enable_tracing(LoggingTracer(tags_color_strings={"haystack.component.input": "\x1b[1;31m", "haystack.component.name": "\x1b[1;34m"}))

class RAGPipeline:
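    """RAG pipeline built on Haystack, backed by Qdrant and Nebius-hosted models.

    Wires two pipelines: an indexing pipeline (clean -> embed -> write) and a
    query pipeline (embed query -> retrieve -> build prompt -> generate).
    """
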
    def __init__(
        self,
        embedding_model_name: str = "BAAI/bge-en-icl",
        llm_model_name: str = "meta-llama/Llama-3.3-70B-Instruct",
        qdrant_path: Optional[str] = None
    ):
        self.embedding_model_name = embedding_model_name
        self.llm_model_name = llm_model_name
        self.qdrant_path = qdrant_path
        if not os.getenv("NEBIUS_API_KEY"):
            logger.warning("NEBIUS_API_KEY not found in environment variables")
        # Resolve the key from the environment at use time; Secret.from_token(None)
        # would raise immediately, which made the warning above unreachable
        self.nebius_api_key = Secret.from_env_var("NEBIUS_API_KEY")
        
        # Initialize document stores and components
        self.init_document_store()
        self.init_components()
        self.build_indexing_pipeline()
        self.build_query_pipeline()
        
    def init_document_store(self):
        """Initialize the Qdrant document store for dense and sparse retrieval."""
        self.document_store = QdrantDocumentStore(
            path=self.qdrant_path,  # local on-disk storage; None falls back to the store's in-memory default
            embedding_dim=4096,  # Output dimension of BAAI/bge-en-icl
            recreate_index=False,
            on_disk=True,
            on_disk_payload=True,
            index="ltu_documents",
            force_disable_check_same_thread=True,
            use_sparse_embeddings=True  # Store sparse vectors for hybrid retrieval
        )
    
    def init_components(self):
        """Initialize all components needed for the pipelines"""
        # Document processing
        self.document_cleaner = DocumentCleaner()
        
        # Embedding components
        self.document_embedder = OpenAIDocumentEmbedder(
            api_base_url="https://api.studio.nebius.com/v1/",
            model=self.embedding_model_name,
            api_key=self.nebius_api_key,
        )
        
        self.text_embedder = OpenAITextEmbedder(
            api_base_url="https://api.studio.nebius.com/v1/",
            model=self.embedding_model_name,
            api_key=self.nebius_api_key,
        )
        
        # Retrievers (the sparse retriever requires sparse embeddings to be
        # written at index time; see build_indexing_pipeline)
        self.sparse_retriever = QdrantSparseEmbeddingRetriever(
            document_store=self.document_store,
            top_k=5
        )
        
        self.embedding_retriever = QdrantEmbeddingRetriever(
            document_store=self.document_store,
            top_k=5
        )
        
        # Document joiner for merging dense and sparse results (hybrid path, currently disabled)
        self.document_joiner = DocumentJoiner()
        
        # Optional ranker for re-ranking combined results (Haystack 2.x class
        # name; uncommenting also requires sentence-transformers)
        # self.ranker = TransformersSimilarityRanker(
        #     model="cross-encoder/ms-marco-MiniLM-L-6-v2",
        #     top_k=5,
        # )
        
        # LLM components
        self.llm = OpenAIGenerator(
            api_base_url="https://api.studio.nebius.com/v1/",
            model=self.llm_model_name,
            api_key=self.nebius_api_key,
            generation_kwargs={
                "max_tokens": 1024,
                "temperature": 0.1,
                "top_p": 0.95,
            }
        )
        
        # Prompt builder (the chat endpoint applies the model's own chat
        # template, so no raw instruction tokens like [INST] are needed here)
        self.prompt_builder = PromptBuilder(
            template="""
            You are a helpful assistant that answers questions based on the provided context.
            
            Context:
            {% for document in documents %}
            {{ document.content }}
            {% endfor %}
            
            Question: {{ question }}
            
            Answer the question based only on the provided context. If the context doesn't contain the answer, say "I don't have enough information to answer this question."
            
            Answer:
            """
        )
    
    def build_indexing_pipeline(self):
        """Build the pipeline for indexing documents"""
        self.indexing_pipeline = Pipeline()
        self.indexing_pipeline.add_component("document_cleaner", self.document_cleaner)
        self.indexing_pipeline.add_component("document_embedder", self.document_embedder)
        self.indexing_pipeline.add_component("document_writer", DocumentWriter(document_store=self.document_store))
        
        # Connect components
        self.indexing_pipeline.connect("document_cleaner", "document_embedder")
        self.indexing_pipeline.connect("document_embedder", "document_writer")
    
    def build_query_pipeline(self):
        """Build the pipeline for querying"""
        self.query_pipeline = Pipeline()
        
        # Add components (the sparse/hybrid path is sketched but disabled)
        self.query_pipeline.add_component("text_embedder", self.text_embedder)
        # self.query_pipeline.add_component("sparse_retriever", self.sparse_retriever)
        self.query_pipeline.add_component("embedding_retriever", self.embedding_retriever)
        # self.query_pipeline.add_component("document_joiner", self.document_joiner)
        # self.query_pipeline.add_component("ranker", self.ranker)
        self.query_pipeline.add_component("prompt_builder", self.prompt_builder)
        self.query_pipeline.add_component("llm", self.llm)
        
        # Connect components (DocumentJoiner takes a single variadic "documents"
        # input, so both retrievers would connect to the same socket)
        self.query_pipeline.connect("text_embedder.embedding", "embedding_retriever.query_embedding")
        # self.query_pipeline.connect("sparse_retriever", "document_joiner")
        # self.query_pipeline.connect("embedding_retriever", "document_joiner")
        # self.query_pipeline.connect("document_joiner", "ranker")
        # self.query_pipeline.connect("ranker", "prompt_builder.documents")
        self.query_pipeline.connect("embedding_retriever.documents", "prompt_builder.documents")
        self.query_pipeline.connect("prompt_builder.prompt", "llm")
    
    def index_documents(self, documents: List[Document]):
        """
        Index documents in the document store.
        
        Args:
            documents: List of Haystack Document objects to index
        """
        logger.info(f"Indexing {len(documents)} documents")
        
        try:
            self.indexing_pipeline.run(
                {"document_cleaner": {"documents": documents}}
                )
            logger.info("Indexing completed successfully")
        except Exception as e:
            logger.error(f"Error during indexing: {e}")
    
    def query(self, question: str, top_k: int = 5) -> Dict[str, Any]:
        """
        Query the RAG pipeline with a question.
        
        Args:
            question: The question to ask
            top_k: Number of documents to retrieve
            
        Returns:
            Dictionary containing the answer and retrieved documents
        """
        logger.info(f"Querying with question: {question}")
        
        try:
            # Run the query pipeline; top_k is passed to the retriever at run
            # time (the retriever stores its init value privately, so mutating
            # an attribute after construction has no effect), and
            # include_outputs_from keeps the retrieved documents in the result
            result = self.query_pipeline.run(
                {
                    "text_embedder": {"text": question},
                    "embedding_retriever": {"top_k": top_k},
                    # For the hybrid path: "sparse_text_embedder": {"text": question},
                    "prompt_builder": {"question": question},
                },
                include_outputs_from={"embedding_retriever"},
            )
            
            # Extract answer and retrieved documents
            answer = result["llm"]["replies"][0]
            documents = result["embedding_retriever"]["documents"]
            
            return {
                "answer": answer,
                "documents": documents, #documents,
                "question": question
            }
        except Exception as e:
            logger.error(f"Error during query: {e}")
            return {
                "answer": f"An error occurred: {str(e)}",
                "documents": [],
                "question": question
            }
    
    def get_document_count(self) -> int:
        """
        Get the number of documents in the document store.
        
        Returns:
            Document count
        """
        return self.document_store.count_documents()
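

if __name__ == "__main__":
    # Minimal usage sketch. Assumes NEBIUS_API_KEY is set in the environment
    # and ./qdrant_data is a writable local path; the sample documents below
    # are illustrative only.
    rag = RAGPipeline(qdrant_path="./qdrant_data")

    if rag.get_document_count() == 0:
        rag.index_documents([
            Document(content="Lulea University of Technology was founded in 1971."),
            Document(content="The main campus is located in Lulea, Sweden."),
        ])

    result = rag.query("When was the university founded?", top_k=3)
    print(result["answer"])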