import yaml
import fitz
import torch
import gradio as gr
from PIL import Image
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import spaces
from langchain_text_splitters import CharacterTextSplitter,RecursiveCharacterTextSplitter


class PDFChatBot:
    """Retrieval-augmented chatbot that answers questions about an uploaded PDF.

    The PDF is split into overlapping character chunks, embedded with a
    sentence-transformers model and indexed in a Chroma vector store. For each
    question the most relevant chunks are retrieved and passed as context to a
    Llama-3-8B-Instruct text-generation pipeline.
    """

    def __init__(self, config_path="config.yaml"):
        """
        Initialize the PDFChatBot instance.

        Parameters:
            config_path (str): Path to the configuration file (default is "config.yaml").
                NOTE(review): this value is currently not read by the class.
        """
        self.processed = False
        # (file name, chunk size, overlap %) that the current vector store was
        # built from; used to detect when the index has to be rebuilt.
        self.processed_signature = None
        self.page = 0
        self.chat_history = []
        # Initialize other attributes to None
        self.prompt = None
        self.documents = None
        self.embeddings = None
        self.vectordb = None
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.chain = None
        self.chunk_size = 512
        self.overlap_percentage = 50
        self.max_chunks_in_context = 2
        self.current_context = None
        self.model_temperatue = 0.5
        self.format_seperator = """\n\n--\n\n"""
        self.pipe = None

    def load_embeddings(self):
        """Load the sentence-transformers embedding model used for indexing and search."""
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        print("Embedding model loaded")

    def load_vectordb(self):
        """Split ``self.documents`` into overlapping chunks and index them in Chroma.

        The chunk size is ``self.chunk_size`` characters. The overlap is
        ``self.overlap_percentage`` percent of the chunk size. Any previously
        built index is dropped first.
        """
        chunk_size = int(self.chunk_size)  # UI sliders may deliver floats
        overlap = int((self.overlap_percentage / 100) * chunk_size)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            length_function=len,
            add_start_index=True,
        )
        docs = text_splitter.split_documents(self.documents)
        if self.vectordb is not None:
            # Drop the old index before rebuilding. The in-process Chroma client
            # may reuse the default collection; without this, chunks from a
            # previous PDF or chunking could still be retrieved.
            self.vectordb.delete_collection()
        self.vectordb = Chroma.from_documents(docs, self.embeddings)
        print("Vector store created")

    @spaces.GPU
    def load_tokenizer(self):
        """Load the Llama-3 tokenizer into ``self.tokenizer``."""
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    @spaces.GPU
    def create_organic_pipeline(self):
        """Create the Llama-3-8B-Instruct text-generation pipeline (bfloat16, on CUDA)."""
        self.pipe = pipeline(
            "text-generation",
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device="cuda",
        )
        print("Model pipeline loaded")

    def get_organic_context(self, query):
        """Retrieve the chunks most relevant to ``query`` and store them in ``self.current_context``.

        At most ``self.max_chunks_in_context`` chunks are retrieved. They are
        joined with ``self.format_seperator``.
        """
        k = int(self.max_chunks_in_context)  # UI sliders may deliver floats
        documents = self.vectordb.similarity_search_with_relevance_scores(query, k=k)
        context = self.format_seperator.join(doc.page_content for doc, _score in documents)
        self.current_context = context
        print("Context Ready")
        print(self.current_context)

    @spaces.GPU
    def create_organic_response(self, history, query):
        """Answer ``query`` using retrieved PDF context.

        Parameters:
            history: Unused; kept for interface compatibility.
            query (str): The user's question.

        Returns:
            str: The generated answer, with the echoed prompt removed.
        """
        self.get_organic_context(query)
        messages = [
            {"role": "system", "content": "From the context given below, answer the question of user \n " + self.current_context},
            {"role": "user", "content": query},
        ]

        prompt = self.pipe.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        temperature = float(self.model_temperatue)
        if temperature > 0:
            generation_kwargs = {"do_sample": True, "temperature": temperature, "top_p": 0.9}
        else:
            # Sampling requires a strictly positive temperature; treat 0 as greedy decoding.
            generation_kwargs = {"do_sample": False}
        outputs = self.pipe(
            prompt,
            max_new_tokens=1024,
            **generation_kwargs,
        )
        print(outputs)
        # The pipeline returns prompt + completion; strip the prompt prefix.
        return outputs[0]["generated_text"][len(prompt):]

    def process_file(self, file):
        """
        Process the uploaded PDF file and initialize necessary components: embeddings, VectorDB and LLM.

        The embedding model and LLM pipeline are only loaded the first time;
        the vector store is rebuilt on every call.

        Parameters:
            file (FileStorage): The uploaded PDF file (its ``.name`` is the path on disk).
        """
        self.documents = PyPDFLoader(file.name).load()
        if self.embeddings is None:
            self.load_embeddings()
        self.load_vectordb()
        if self.pipe is None:
            self.create_organic_pipeline()

    @spaces.GPU
    def generate_response(self, history, query, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context):
        """Answer ``query`` about ``file`` and append the answer to the last chat turn.

        The PDF is (re)indexed whenever the file or the chunking settings differ
        from those used for the current index.

        Returns:
            tuple: (updated history, "") — the empty string clears the input box.

        Raises:
            gr.Error: If no question was submitted or no PDF was uploaded.
        """
        self.chunk_size = chunk_size
        self.overlap_percentage = chunk_overlap_percentage
        self.model_temperatue = model_temperature
        self.max_chunks_in_context = max_chunks_in_context

        if not query:
            raise gr.Error(message='Submit a question')
        if not file:
            raise gr.Error(message='Upload a PDF')

        signature = (file.name, self.chunk_size, self.overlap_percentage)
        if not self.processed or signature != self.processed_signature:
            self.process_file(file)
            self.processed = True
            self.processed_signature = signature

        result = self.create_organic_response(history="", query=query)

        if history:
            # Entries may be tuples (from add_text) or lists, and the bot slot may be None.
            last_turn = list(history[-1])
            last_turn[-1] = (last_turn[-1] or "") + result
            history[-1] = last_turn
        else:
            history = [[query, result]]
        return history, ""

    def render_file(self, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context):
        """Render page ``self.page`` of the PDF as a 300-DPI RGB image and store the UI settings.

        Returns:
            PIL.Image.Image: The rendered page.
        """
        self.chunk_size = chunk_size
        self.overlap_percentage = chunk_overlap_percentage
        self.model_temperatue = model_temperature
        self.max_chunks_in_context = max_chunks_in_context
        # Close the document once the page has been rasterized; Image.frombytes copies the pixels.
        with fitz.open(file.name) as doc:
            page = doc[self.page]
            pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
            image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
        return image

    def add_text(self, history, text):
        """
        Add user-entered text to the chat history.
        Parameters:
            history (list): List of chat history tuples.
            text (str): User-entered text.
        Returns:
            list: Updated chat history.
        Raises:
            gr.Error: If ``text`` is empty.
        """
        if not text:
            raise gr.Error('Enter text')
        history.append((text, ''))
        return history