import streamlit as st
import os
from streamlit_chat import message
import numpy as np
import pandas as pd
from io import StringIO
import PyPDF2
from tqdm.auto import tqdm
import math
from transformers import pipeline
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms import HuggingFaceHub
from langchain.chains.summarize import load_summarize_chain
import re
from dotenv import load_dotenv
# import json

# st.config(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python")

# from datasets import load_dataset

# dataset = load_dataset("wikipedia", "20220301.en", split="train[240000:250000]")


# wikidata = []

# for record in dataset:
#     wikidata.append(record["text"])

# wikidata = list(set(wikidata))
# # print("\n".join(wikidata[:5]))
# # print(len(wikidata))

from sentence_transformers import SentenceTransformer
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device != 'cuda':
    st.markdown(f"Note: Using {device}. Expected slow responses compare to CUDA-enabled GPU. Please be patient thanks")

model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
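# all-MiniLM-L6-v2 produces 384-dimensional embeddings; this dimension is reused below when creating the Pinecone index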
st.divider()

# Create an index (Pinecone vector database)
# import pinecone

from pinecone.grpc import PineconeGRPC
from pinecone import ServerlessSpec

# Load environment variables from the .env file before reading any of them
load_dotenv()

PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENV = os.getenv("PINECONE_ENV")
PINECONE_ENVIRONMENT = os.getenv("PINECONE_ENVIRONMENT")
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# pc = PineconeGRPC( api_key=os.environ.get("PINECONE_API_KEY") ) # Now do stuff if 'my_index' not in pc.list_indexes().names(): pc.create_index( name='my_index', dimension=1536, metric='euclidean', spec=ServerlessSpec( cloud='aws', region='us-west-2' ) )

def connect_pinecone():
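    """Create a Pinecone gRPC client using the API key loaded from the environment."""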
    pinecone = PineconeGRPC(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)
    # st.code(pinecone)
    # st.divider()
    # st.text(pinecone.list_indexes().names())
    # st.divider()
    # st.text(f"Succesfully connected to the pinecone")
    return pinecone

def get_pinecone_semantic_index(pinecone):
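    """Return the semantic-search index, creating it (serverless, cosine metric) if it does not exist yet."""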
    index_name = "sematic-search-index"

    # only create the index if it doesn't already exist
    if index_name not in pinecone.list_indexes().names():
        pinecone.create_index(
            name=index_name,
            dimension=model.get_sentence_embedding_dimension(),
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1")
        )
    # now connect to index
    index = pinecone.Index(index_name)
    # st.text(f"Succesfully connected to the pinecone index")
    return index



def prompt_engineer(text, longtext, query):
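    """Build a RAG prompt from the retrieved context and the user query, send it to the hosted LLM, and return the generated answer."""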
    summary_prompt_inst = """
    write a concise summary of the following text delimited by triple backquotes.
    return your response in bullet points which convers the key points of the text.

    ```{context}```

    BULLET POINT SUMMARY:
    """
    # Load the summarization pipeline with the specified model
    # summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

    # Generate the prompt
    # prompt = summary_prompt_template.format(text=text)

    # Generate the summary
    # summary = summarizer(prompt, max_length=1024, min_length=50)[0]["summary_text"]

    # try:
    #     sllm = HuggingFaceHub(
    #         repo_id="meta-llama/Meta-Llama-3-8B-Instruct", model_kwargs={"temperature": 0.1, "max_new_tokens": 256, "task":"summarization"}
    #     )
    #     st.write("Summary Chat llm connection started..")
    # except Exception as e:
    #     st.error(f"Error invoke: {e}")

    # from langchain.chains.combine_documents import create_stuff_documents_chain
    # from langchain.chains.llm import LLMChain
    # from langchain_core.prompts import ChatPromptTemplate
    # from langchain_core.documents import Document

    # docs =  Document(page_content=longtext, metadata={"source": "pinecone"})
    # st.write(docs)
    # # Define prompt
    # prompt = ChatPromptTemplate.from_messages(
    #     [("system", summary_prompt_template)]
    # )
    
    # # Instantiate chain
    # chain = create_stuff_documents_chain(sllm, prompt)
    
    # # Invoke chain
    # summary = chain.invoke({"context": [docs]})

    summary_prompt_template = ChatPromptTemplate.from_template(summary_prompt_inst)
    summary_prompt = summary_prompt_template.format(context=longtext)
    
    with st.sidebar:
        st.divider()
        # st.markdown("*:red[Text Summary Generation]* from above Top 5 **:green[similarity search results]**.")


    GENERATION_PROMPT_TEMPLATE = """
    Instructions:
    -------------------------------------------------------------------------------------------------------------------------------
    Answer the question based only on the context below:
    - You are a research AI expert at reading and explaining research papers.
    - For questions that are out of context, reply with: The question is out of context.
    - Always keep answers simple and nicely formatted, without incomplete sentences.
    - Give the answer in at least 5 separate lines, in addition to the title info.
    - Only if the question is relevant to the context, provide Doc Title: <title> Paragraph: <Paragraph> Page No: <pagenumber>
    -------------------------------------------------------------------------------------------------------------------------------
    {context}
    -------------------------------------------------------------------------------------------------------------------------------
    Answer the question based on the above context: {question}
    """

    prompt_template = ChatPromptTemplate.from_template(GENERATION_PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=text, question=query)
    response_text = ""
    result = ""
    
    try:
        llm = HuggingFaceHub(
            repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
            task="text-generation",
            model_kwargs={"temperature": 0.1},
        )
        st.write("GEN llm connection started..")
        # summary = llm.invoke(summary_prompt)
        # st.write(summary)
        # st.divider()
        response_text = llm.invoke(prompt)
        escaped_query = re.escape(query)
        result = re.split(f'Answer the question based on the above context: {escaped_query}\n',response_text)[-1]
        st.write("reponse generated see chat window 👉🏻")
        st.divider()
    except Exception as e:
        st.error(f"Error invoke: {e}")

    return result

def chat_actions():
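    """Handle a chat submission: embed the query, retrieve the top matches from Pinecone, and generate a response."""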
    
    pinecone = connect_pinecone()
    index = get_pinecone_semantic_index(pinecone)

    st.session_state["chat_history"].append(
        {"role": "user", "content": st.session_state["chat_input"]},
    )

    query = st.session_state["chat_input"]
    query_embedding = model.encode(query)
    # create the query vector
    query_vector = query_embedding.tolist()
    # now query vector database
    result = index.query(vector=query_vector, top_k=5, include_metadata=True)  # returns the top-5 matches with scores and metadata

    # Collect the top matches into rows for display
    data = []
    consolidated_text = ""
    for i, res in enumerate(result['matches'], start=1):
        data.append([f"{i}⭐", res['score'], res['metadata']['text']])
        consolidated_text += res['metadata']['text']

    # Create a DataFrame from the list of lists
    resdf = pd.DataFrame(data, columns=['TopRank', 'Score', 'Text'])

    with st.sidebar:
        st.markdown("*:red[semantic search results]* with **:green[Retrieval Augmented Generation]** ***(RAG)***.")
        st.dataframe(resdf)
        bytesize = consolidated_text.encode("utf-8")
        p = math.pow(1024, 2)
        mbsize = round(len(bytesize) / p, 2)
        st.write(f"Text length of {len(consolidated_text)} characters with {mbsize}MB size")
        response = prompt_engineer(consolidated_text[:1024], consolidated_text, query)

    # append the generated response to the chat history (once, if any matches were found)
    if result['matches']:
        st.session_state["chat_history"].append(
            {
                "role": "assistant",
                "content": f"{response}",
            },
        )

if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

st.chat_input("show me the contents of ML paper published on xxx with article no. xx?", on_submit=chat_actions, key="chat_input")

for i in st.session_state["chat_history"]:
    with st.chat_message(name=i["role"]):
        st.write(i["content"])

def print_out(pages):
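    """Write each PDF page's extracted text to the Streamlit app."""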
    for i in range(len(pages)):
        text = pages[i].extract_text().strip()
        st.write(f"Page {i} : {text}")

def combine_text(pages):
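    """Concatenate the extracted text of all PDF pages and report its size."""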
    concatenates_text = ""
    for page in tqdm(pages):
        text = page.extract_text().strip()
        concatenates_text += text
    bytesize = concatenates_text.encode("utf-8")
    p = math.pow(1024, 2)
    mbsize = round(len(bytesize) / p, 2)
    st.write(f"There are {len(concatenates_text)} characters in the pdf with {mbsize}MB size")
    return concatenates_text

def split_into_chunks(text, chunk_size):
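    """Split the text into fixed-size character chunks."""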

    chunks = []
    for i in range(0, len(text), chunk_size):
        chunks.append(text[i:i + chunk_size])

    return chunks

def create_embeddings():
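    """Read the uploaded PDFs, chunk their text, embed each chunk, and upsert the vectors into the Pinecone index."""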
    # Get the uploaded file
    inputtext = ""
    with st.sidebar:
        uploaded_files = st.session_state["uploaded_files"]
        for uploaded_file in uploaded_files:
            # Read the contents of the file
            reader = PyPDF2.PdfReader(uploaded_file)
            pages = reader.pages
            print_out(pages)
            inputtext = combine_text(pages)

    # connect to pinecone index
    pinecone = connect_pinecone()
    index = get_pinecone_semantic_index(pinecone)

    # Pinecone's maximum metadata size per vector is 40KB (~40,000 bytes); each text character is 1-2 bytes,
    # so a chunk size of roughly 10,000 characters stays safely under that limit.
    chunk_size = 10000
    batch_size = 2
    chunks = split_into_chunks(inputtext, chunk_size)

    for i in tqdm(range(0, len(chunks), batch_size)):
        # find end of batch
        end = min(i + batch_size, len(chunks))
        # create ids batch (use a separate variable so the outer loop index is not shadowed)
        ids = [str(n) for n in range(i, end)]
        # create metadata batch
        metadata = [{"text": text} for text in chunks[i:end]]
        # create embeddings
        xc = model.encode(chunks[i:end])
        # create records list for upsert
        records = list(zip(ids, xc, metadata))
        # upsert records
        index.upsert(vectors=records)

    with st.sidebar:
        st.write("created vector embeddings!")
        # check no of records in the index
        st.write(f"{index.describe_index_stats()}")


    # Display the contents of the file
    # st.write(file_contents)

with st.sidebar:
    st.markdown("""
    ***:red[Follow these steps]***
    - Upload a PDF file to create embeddings from your own docs using the model
    - Wait for the success message confirming the embeddings were created
    - It takes a couple of minutes after uploading the PDF
    - Then chat with your documents with the help of this RAG system
    - It generates prompted responses based on the uploaded PDF
    - Provides summarized results and Q&As using GPT models
    - This system is already trained on some Wikipedia datasets too
    """)
    uploaded_files = st.file_uploader('Choose your .pdf file', type="pdf", accept_multiple_files=True, key="uploaded_files", on_change=create_embeddings)
    # for uploaded_file in uploaded_files:
        # To read file as bytes:
        # bytes_data = uploaded_file.getvalue()
        # st.write(bytes_data)

        # To convert to a string based IO:
        # stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
        # st.write(stringio)

        # To read file as string:
        # string_data = stringio.read()
        # st.write(string_data)

        # Can be used wherever a "file-like" object is accepted:
        # dataframe = pd.read_csv(uploaded_file)
        # st.write(dataframe)

        # reader = PyPDF2.PdfReader(uploaded_file)
        # pages = reader.pages
        # print_out(pages)
        # combine_text(pages)
        # promt_engineer(text)