import gradio as gr
from langchain.document_loaders import ArxivLoader
from PyPDF2 import PdfReader
from langchain_community.llms import HuggingFaceHub
from langchain.text_splitter import TokenTextSplitter
from langchain.chains.summarize import load_summarize_chain
from langchain.document_loaders import PyPDFLoader
from transformers import pipeline

from dotenv import load_dotenv
import os

# Populate os.environ from a local .env file (when one exists), then read the
# Hugging Face token used below for login and for the embeddings endpoint.
load_dotenv()
hugging_api_key = os.environ.get("HUGGING_API_KEY")

from groq import AsyncGroq
from groq import Groq

from langchain_groq import ChatGroq
from langchain.document_loaders import ArxivLoader
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from huggingface_hub import login
# Authenticate with the Hugging Face Hub and build the shared embedding model
# (used by the "Chat with pdf" retrieval) and the LangChain Groq chat model.
login(hugging_api_key)
embedding_model = HuggingFaceHubEmbeddings(huggingfacehub_api_token=hugging_api_key)
# SECURITY: the Groq key must not live in source code (a committed key is
# leaked and should be rotated). Read it from the environment instead; the
# load_dotenv() call above lets it come from a local .env file.
# NOTE(review): `llm` is not referenced anywhere else in this file.
llm = ChatGroq(temperature=0, model_name="llama3-70b-8192", api_key=os.getenv("GROQ_API_KEY"))

def display_results(result):
    """Return the string entries of ``result`` joined by a single newline.

    Entries that already end with a newline character therefore end up
    separated by a blank line, which Markdown renders as a paragraph break.
    """
    newline = "\n"
    return newline.join(result)

def summarize_pdf(pdf_file_path, max_length):
    """Summarize an uploaded PDF chunk by chunk with the Groq LLM.

    Args:
        pdf_file_path: Value of the ``gr.File`` upload component (anything
            ``PdfReader`` accepts); ``None`` when nothing was uploaded.
        max_length: Value of the "max length" slider.
            NOTE(review): currently unused -- it is not forwarded to the model,
            so the slider has no effect on the output.

    Returns:
        The per-chunk summaries separated by blank lines (Markdown
        paragraphs), or a short message when no file was provided.
    """
    if pdf_file_path is None:
        return "Please upload a PDF file."

    reader = PdfReader(pdf_file_path)
    # A newline between pages keeps the last word of one page from being glued
    # to the first word of the next; `or ""` tolerates pages with no text layer.
    text = "\n".join(page.extract_text() or "" for page in reader.pages)

    # "llama3-70b-8192" has an 8192-token context window, so 8192-token chunks
    # left no room for the prompt or the answer. Use the same chunking as the
    # arXiv summarizer. (Splitter tokens are only an approximation of the
    # model's own tokenizer.)
    text_splitter = TokenTextSplitter(chunk_size=5700, chunk_overlap=100)
    chunks = text_splitter.split_text(text)

    # Blank line between chunk summaries so they do not run into each other.
    return "\n\n".join(summarize_text(chunk) for chunk in chunks)

def summarize_text(text, max_tokens=8192):
    """Summarize ``text`` with the Groq-hosted ``llama3-70b-8192`` model.

    Args:
        text: The text to summarize (one chunk of a paper, for the callers in
            this file).
        max_tokens: Upper bound on the generated summary length, forwarded to
            the chat-completion call. Defaults to the previous fixed value.
            NOTE(review): prompt + max_tokens may exceed the model's 8192-token
            context for large chunks -- confirm against Groq's limits.

    Returns:
        The model's summary, or "" if the response carried no content.
    """
    # The key is read from the environment (GROQ_API_KEY, loadable from .env)
    # instead of being hard-coded in source.
    sum_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    messages = [
        {
            "role": "system",
            "content": "You are summarizer. If I give you the whole text you should summarize it.  And you don't need the title and author",
        },
        {
            "role": "user",
            "content": f"Summarize the paper. The whole text is {text}",
        },
    ]
    response = sum_client.chat.completions.create(
        messages=messages,
        model="llama3-70b-8192",
        temperature=0,
        max_tokens=max_tokens,
        top_p=1,
        stop=None,
    )
    # Callers concatenate summaries, so never hand them None.
    return response.choices[0].message.content or ""




def remove_first_sentence_and_title(text):
    """Strip a leading sentence and a Markdown ``**Title:**`` line from ``text``.

    1. Everything up to and including the first ". " (the first sentence) is
       removed. If ``text`` contains no ". ", nothing is removed. Previously
       ``find`` returned -1 here and the ``+ 2`` offset silently dropped the
       first character.
    2. The first ``**Title:**`` marker is removed together with the rest of
       its line, up to and including the newline. If there is no newline, it
       is removed to the end of the text. Text before the marker on the same
       line is kept.

    Returns:
        The remaining text with surrounding whitespace stripped.
    """
    _, sentence_end, after_sentence = text.partition(". ")
    if sentence_end:
        text = after_sentence

    before_title, title_marker, after_title = text.partition("**Title:**")
    if title_marker:
        # Drop the remainder of the title line; "" when it is the last line.
        _, _, after_title_line = after_title.partition("\n")
        text = before_title + after_title_line

    return text.strip()



def summarize_arxiv_pdf(query):
    """Build a Markdown report for an arXiv query: metadata, abstract, LLM summary.

    The full text of the papers returned for ``query`` is split into
    ~5700-token chunks. Each chunk is summarized with ``summarize_text``, and
    the chunk summaries are combined into the "Lazyman Summary". Title,
    authors, link and abstract come from ``loader.get_summaries_as_docs()``.

    Args:
        query: arXiv number (or search query) entered in the UI.

    Returns:
        The Markdown report (also printed to stdout), or a short message when
        the search returned no papers.
    """
    loader = ArxivLoader(query=query, load_max_docs=10)
    documents = loader.load()
    if not documents:
        # Nothing to summarize; the old unused `documents[0]` access raised
        # IndexError here.
        return f"No arXiv paper found for: {query}"

    text_splitter = TokenTextSplitter(chunk_size=5700, chunk_overlap=100)
    chunks = text_splitter.split_documents(documents)

    # Blank line between chunk summaries so they do not run into each other.
    ref_summary = "\n\n".join(summarize_text(chunk.page_content) for chunk in chunks)
    ref_summary = ref_summary.replace("Here is a summary of the paper:", "").strip()

    # NOTE(review): up to 10 papers can match `query`. All their chunks feed the
    # single ref_summary above, which is then attached to *every* paper below.
    # That is only correct when the query resolves to one paper (e.g. an exact
    # arXiv number).
    summaries = []
    for doc in loader.get_summaries_as_docs():
        title = doc.metadata.get("Title")
        authors = doc.metadata.get("Authors")
        url = doc.metadata.get("Entry ID")
        summaries.extend([
            f"**{title}**\n",
            f"**Authors:** {authors}\n",
            f"**View full paper:** [Link to paper]({url})\n",
            f"**Summary:** {doc.page_content}\n",
            f"**Lazyman Summary:**\n ",
            f"{ref_summary}",
        ])
    summaries = display_results(summaries)
    print(summaries)
    return summaries


# Shared async Groq client used by both chat tabs. The key is read from the
# environment (GROQ_API_KEY, loadable from .env via load_dotenv) rather than
# hard-coded: a key committed to source is leaked and should be rotated.
client = AsyncGroq(api_key=os.getenv("GROQ_API_KEY"))

async def chat_with_replit(message, history):
    """Stream a general-purpose chat answer from the Groq LLM.

    Args:
        message: The user's latest message.
        history: Previous turns from ``gr.ChatInterface``. Two shapes are
            accepted: ``[user, assistant]`` pairs (tuples format) or
            ``{"role": ..., "content": ...}`` dicts (``type="messages"``).

    Yields:
        The accumulated assistant reply after each streamed chunk, so the UI
        shows the answer growing in place.
    """
    # Exactly one system prompt, at the start. It used to be re-inserted before
    # every past turn and was missing entirely on the first message.
    messages = [
        {"role": "system", "content": "You are assistor. I will ask you some questions than you should answer!"}
    ]
    for turn in history:
        if isinstance(turn, dict):
            # messages format: already role/content shaped.
            messages.append({"role": turn["role"], "content": str(turn["content"])})
        else:
            # tuples format: [user_text, assistant_text]
            messages.append({"role": "user", "content": str(turn[0])})
            messages.append({"role": "assistant", "content": str(turn[1])})
    messages.append({"role": "user", "content": str(message)})

    print(messages)

    response_content = ""
    stream = await client.chat.completions.create(
        messages=messages,
        model="llama3-70b-8192",
        temperature=0,
        max_tokens=1024,
        top_p=1,
        stop=None,
        stream=True,
    )
    async for chunk in stream:
        content = chunk.choices[0].delta.content
        if content:
            response_content += content
        yield response_content

# Replit badge <script> tag.
# NOTE(review): not referenced anywhere in this file (gr.Blocks() is created
# without it), so it currently has no effect on the page.
js = (
    '<script src="https://replit.com/public/js/replit-badge-v2.js" '
    'theme="dark" position="bottom-right"></script>'
)


# Single-entry cache for chat_with_replit_pdf: the paper whose chunks are
# currently embedded. Follow-up questions about the same paper therefore skip
# the re-download and re-embedding. The former local `old_doi` variable was
# reset on every call, so that check never skipped anything.
_pdf_chat_cache = {"doi": None, "metadata": None, "vector_store": None}


async def chat_with_replit_pdf(message, history, doi_num):
    """Answer a question about one arXiv paper with retrieval-augmented chat.

    On the first message for a given ``doi_num``, the paper is loaded with
    ``ArxivLoader``, split, and embedded into a Chroma store. Later messages
    for the same paper reuse it. The three chunks most similar to ``message``
    and the paper metadata are sent to the Groq LLM as system context.

    Args:
        message: The user's latest chat message.
        history: Previous turns supplied by ``gr.ChatInterface``.
            NOTE(review): not sent to the model, so follow-up questions carry
            no conversational context.
        doi_num: arXiv number (or search query) from the "doi" textbox.

    Returns:
        The model's answer, or a short explanatory message when no paper id
        was entered or nothing was found.
    """
    doi = str(doi_num).strip()
    if not doi:
        return "Please enter an arXiv number in the 'doi' box."

    if _pdf_chat_cache["doi"] != doi:
        loader = ArxivLoader(query=doi, load_max_docs=10)
        documents = loader.load_and_split()
        if not documents:
            return f"No arXiv paper found for: {doi}"
        # Invalidate the cache entry and drop the previous paper's collection
        # before building the new one. A failure below then cannot leave a
        # half-replaced store behind, and chunks of the old paper cannot be
        # retrieved for the new one. (In-memory Chroma clients may share
        # collections within a process.)
        previous_store = _pdf_chat_cache["vector_store"]
        _pdf_chat_cache.update(doi=None, metadata=None, vector_store=None)
        if previous_store is not None:
            previous_store.delete_collection()
        vector_store = Chroma.from_documents(documents, embedding_model)
        _pdf_chat_cache.update(doi=doi, metadata=documents[0].metadata, vector_store=vector_store)

    metadata = _pdf_chat_cache["metadata"]
    results = _pdf_chat_cache["vector_store"].similarity_search(str(message), k=3)
    relevant_content = "\n\n".join(doc.page_content for doc in results)

    # System context first, then the question it applies to.
    messages = [
        {
            "role": "system",
            "content": f"You should answer about this arxiv paper for {doi_num}.\n"
            f"This is the metadata of the paper:{metadata}.\n"
            f"This is relevant information of the paper:{relevant_content}.\n",
        },
        {
            "role": "user",
            "content": str(message),
        },
    ]

    print(messages)

    response = await client.chat.completions.create(
        messages=messages,
        model="llama3-70b-8192",
        temperature=0,
        max_tokens=1024,
        top_p=1,
        stop=None,
        stream=False,
    )
    return response.choices[0].message.content


# Gradio UI with four tabs wired to the handlers above. Components are placed
# in the order they are created inside the `with` contexts, so the statement
# order here defines the page layout.
with gr.Blocks() as app:
    # Tab 1: summarize an arXiv paper looked up by its number.
    with gr.Tab(label="Arxiv summarization"):
        with gr.Column():
            number = gr.Textbox(label="Enter your arxiv number")
            sumarxiv_btn = gr.Button(value="summarize-arxiv")
        with gr.Column():
            outputs = gr.Markdown(label="Summary", height=1000)
    # Button click: textbox value -> summarize_arxiv_pdf -> Markdown output.
    sumarxiv_btn.click(summarize_arxiv_pdf, inputs=number, outputs=outputs)    
    # Tab 2: summarize an uploaded PDF.
    with gr.Tab(label="Local summarization"):
        with gr.Row():
            with gr.Column():
                input_path = gr.File(label="Upload PDF file")
            with gr.Column():
                # set_temperature = gr.Slider(0, 1, value=0, step=0.1, label="temperature")
                set_max_length = gr.Slider(512, 4096, value=2048, step=512, label="max length")
                sumlocal_btn = gr.Button(value="summarize-local")
        with gr.Row():
            output_local = gr.Markdown(label="summary", height=1000)
    # Button click: (uploaded file, slider value) -> summarize_pdf(pdf_file_path, max_length).
    sumlocal_btn.click(summarize_pdf, inputs=[input_path, set_max_length], outputs=output_local)
    # Tab 3: general streaming chatbot (chat_with_replit is an async generator).
    # No `type` is passed, so history arrives in ChatInterface's default format.
    with gr.Tab(label="ChatBot"):
        gr.ChatInterface(chat_with_replit,
                       examples=[
                           "Explain about the attention is all you need",
                           "Who is the inventor of the GAN",
                           "What is the main idea style transfer?"
                       ])
    # Tab 4: retrieval chat about one arXiv paper. The "doi" textbox is passed
    # as the extra third argument (doi_num) of chat_with_replit_pdf.
    with gr.Tab(label="Chat with pdf"):
        gr.ChatInterface(fn = chat_with_replit_pdf,
                         additional_inputs = [
                             gr.Textbox(label="doi", placeholder="Enter doi number")
                         ],
                        type="messages")
# NOTE(review): launch() runs at import time (no `if __name__ == "__main__":`
# guard), so merely importing this module starts the server. The `js` badge
# script defined above is not passed to gr.Blocks, so it is not rendered.
app.launch()