import gradio as gr
import pandas as pd
import numpy as np
import os
import time
import re
import json
from auditqa.sample_questions import QUESTIONS
from auditqa.reports import POSSIBLE_REPORTS
from auditqa.engine.prompts import audience_prompts, answer_prompt_template
from auditqa.doc_process import process_pdf
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.llms import HuggingFaceEndpoint
from dotenv import load_dotenv
# Load variables from a local .env file (if one exists) into the process environment.
load_dotenv()
# Token for the Hugging Face inference endpoint used by chat(). Fail fast at
# import time with an actionable message instead of a bare KeyError.
HF_token = os.environ.get("HF_TOKEN")
if not HF_token:
    raise RuntimeError("HF_TOKEN is not set; add it to the environment or to a .env file.")
# Built once at import time by process_pdf(); chat() looks up the "MWTS" and
# "Consolidated" entries of this mapping.
vectorstores = process_pdf()
async def chat(query, history, sources, reports):
    """Answer ``query`` with a retrieve-then-generate pipeline and update the chat.

    Retrieves up to three passages (score threshold 0.6) from the vector store
    selected by ``sources``. It then asks the Hugging Face inference endpoint to
    answer from those passages and yields a single update for the UI.

    Args:
        query: The user's question.
        history: Chatbot history as a list of ``(question, answer)`` pairs. The
            last pair is overwritten with ``(query, answer)``. Its existing
            answer (``None`` when freshly appended by ``start_chat``) is used as
            a prefix of the new answer.
        sources: Value of the "Select source" dropdown, either a single string
            or a list of strings. Empty/``None`` falls back to
            ``"Consolidated Reports"``. Only ``"Ministry"`` selects the "MWTS"
            store; anything else uses the "Consolidated" store.
        reports: Value of the "specific reports" dropdown. Currently only
            logged; it does not influence retrieval.

    Yields:
        ``(history, docs_html, output_query, output_language)``:
        - ``history``: the updated history as a list of tuples.
        - ``docs_html``: the retrieved passages rendered with ``make_html_source``.
        - ``output_query``: the retrieval query shown in the UI (always ``""`` for now).
        - ``output_language``: the answer language (always ``"English"``).
    """
    print(f">> NEW QUESTION : {query}")
    print(f"history:{history}")
    print(f"sources:{sources}")
    print(f"reports:{reports}")
    output_query = ""
    output_language = "English"

    # The audience selector is disabled in the UI, so answers always target
    # experts. Unknown audiences fall back to the expert prompt.
    audience = "Experts"
    audience_key = {"Children": "children", "General public": "general"}.get(audience, "experts")
    audience_prompt = audience_prompts[audience_key]

    # Normalise the dropdown value: it may arrive as None, a string or a list.
    # The previous `sources == "Ministry"` test silently failed for list values.
    if not sources:
        sources = ["Consolidated Reports"]
    elif isinstance(sources, str):
        sources = [sources]

    if "Ministry" in sources:
        vectorstore = vectorstores["MWTS"]
    else:
        vectorstore = vectorstores["Consolidated"]

    # Retrieve context. The question list currently holds only the raw query
    # (no reformulation step), but the loop keeps room for several questions.
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.6, "k": 3},
    )
    question_lst = [query]
    context_retrieved = []
    context_retrieved_lst = []
    for question in question_lst:
        context_retrieved = retriever.invoke(question)
        context_retrieved_lst.append("\n\n".join(doc.page_content for doc in context_retrieved))

    # Build the RAG chain. The endpoint is prompted in Llama-3 format (per the
    # original author's notes), hence the "<|eot_id|>" stop token.
    # See https://huggingface.co/blog/llama3#how-to-prompt-llama-3
    prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    llm_qa = HuggingFaceEndpoint(
        endpoint_url="https://mnczdhmrf7lkfd9d.eu-west-1.aws.endpoints.huggingface.cloud",
        task="text-generation",
        huggingfacehub_api_token=HF_token,
        truncate=1500,
        stop=["<|eot_id|>"],
        streaming=True,
        model_kwargs={},
    )
    chain = prompt | llm_qa | StrOutputParser()

    # NOTE(review): chain.invoke is synchronous and blocks the event loop inside
    # this async handler; consider chain.ainvoke once verified with this endpoint.
    answer_lst = [
        chain.invoke({"context": context, "question": question, 'audience': audience_prompt, 'language': 'english'})
        for question, context in zip(question_lst, context_retrieved_lst)
    ]

    # Sources panel: passages from the last retrieval, numbered from 1.
    docs_html = "".join(make_html_source(doc, i) for i, doc in enumerate(context_retrieved, 1))

    previous_answer = history[-1][1]
    if previous_answer is None:
        previous_answer = ""
    answer_yet = parse_output_llm_with_sources(previous_answer + answer_lst[0])
    history[-1] = (query, answer_yet)
    history = [tuple(x) for x in history]

    # NOTE: the event wiring in this file binds three outputs; the trailing
    # language value is not bound to any component there.
    yield history, docs_html, output_query, output_language
def make_html_source(source, i):
    """Render one retrieved document as a numbered source card.

    ``chat`` joins these cards into the HTML shown in the Sources tab.

    Args:
        source: A retrieved document exposing ``page_content`` (str) and a
            ``metadata`` mapping with at least ``'file_path'`` and ``'page'``.
        i: Position of the document, shown as "Doc {i}".

    Returns:
        The card text, containing the header line
        "Doc {i} - {file_path} - Page {page}" followed by the stripped content.
    """
    meta = source.metadata
    content = source.page_content.strip()
    # int() normalises page numbers stored as floats or numeric strings.
    card = f"""
        
            
                Doc {i} - {meta['file_path']} - Page {int(meta['page'])}
                {content}
             
            
         
        """
    return card
# Matches "[Doc 1]", "[Doc1]" or "[Doc 1, Doc 2, ...]"; the single capture
# group holds the text between the brackets.
_DOC_REF_PATTERN = re.compile(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]')


def parse_output_llm_with_sources(output):
    """Replace ``[Doc N]`` citation markers in an LLM answer with their numbers.

    ``"A [Doc 1, Doc 2] B"`` becomes ``"A 12 B"``. Text outside citation
    markers is returned unchanged, even when it begins with "Doc".

    Args:
        output: The raw answer text.

    Returns:
        The answer with each citation marker replaced by its concatenated numbers.
    """
    # With one capturing group, re.split alternates plain text and captured
    # marker contents: even indices are text, odd indices are citations.
    # Checking the index (not `startswith("Doc")`) keeps words like
    # "Documentation" from being mangled.
    pieces = _DOC_REF_PATTERN.split(output)
    parts = []
    for index, piece in enumerate(pieces):
        if index % 2 == 0:
            parts.append(piece)
            continue
        numbers = [ref.lower().replace("doc", "").strip() for ref in piece.split(",")]
        parts.append("".join(numbers))
    return "".join(parts)
# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------
# Visual theme for the Blocks app: blue primary / red secondary hues, with the
# Poppins web font falling back to system sans-serif fonts.
_theme_fonts = [gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"]
theme = gr.themes.Base(primary_hue="blue", secondary_hue="red", font=_theme_fonts)
# Welcome message shown as the first chatbot message (Markdown).
init_prompt = """
Hello, I am Audit Q&A, a conversational assistant designed to help you understand audit reports. I will answer your questions by **crawling through the audit reports published by the Office of the Auditor General**.
❓ How to use
- **Examples** (tab on right): If this is your first time using this app, we have curated some example questions. Select a particular question from a category of questions.
- **Reports** (tab on right): You can address your question to a specific report or to a collection of reports, such as the Consolidated Annual Report or district- or department-focused reports. If you don't select anything, the Consolidated report is used to answer your question.
- **Sources** (tab on right): This tab displays the paragraphs from the report that the answer relied on, to help you assess or fact-check whether the answer provided by the Audit Q&A assistant is correct.
⚠️ Limitations
- *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
- The Audit Q&A assistant is a generative AI and is therefore not deterministic, so the same question may receive different answers.
What do you want to learn?
"""
# Setting Tabs
# Layout: a main "AuditQ&A" tab (chat on the left; Examples / Reports / Sources
# panel on the right) and an "About" tab. Event wiring follows the layout.
with gr.Blocks(title="Audit Q&A", css="style.css", theme=theme, elem_id="main-component") as demo:
    with gr.Tab("AuditQ&A"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    value=[(None, init_prompt)],
                    show_copy_button=True,
                    show_label=False,
                    elem_id="chatbot",
                    layout="panel",
                    avatar_images=(None, "data-collection.png"),
                )
                with gr.Row(elem_id="input-message"):
                    textbox = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )
            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                        # Hidden textbox that receives a clicked example; its
                        # .change event (wired below) starts a chat turn.
                        examples_hidden = gr.Textbox(visible=False)
                        question_categories = list(QUESTIONS.keys())
                        first_key = question_categories[0]
                        dropdown_samples = gr.Dropdown(
                            question_categories,
                            value=first_key,
                            interactive=True,
                            show_label=True,
                            label="Select a category of sample questions",
                            elem_id="dropdown-samples",
                        )
                        # One row of examples per category; only the first is
                        # visible until another category is picked.
                        samples = []
                        for i, key in enumerate(question_categories):
                            with gr.Row(visible=(i == 0)) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                )
                            samples.append(group_examples)
                    with gr.Tab("Reports", elem_id="tab-config", id=2):
                        gr.Markdown("Reminder: To get better results select the specific report/reports")
                        # Single-select dropdown: its default must be a plain
                        # string (not a list) because chat() picks the Ministry
                        # vector store based on this value.
                        dropdown_sources = gr.Dropdown(
                            ["Consolidated Reports", "District", "Ministry"],
                            label="Select source",
                            value="Ministry",
                            interactive=True,
                        )
                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )
                        # An audience selector used to live here; chat() now
                        # always answers for the "Experts" audience.
                        output_query = gr.Textbox(label="Query used for retrieval", show_label=True, elem_id="reformulated-query", lines=2, interactive=False)
                    with gr.Tab("Sources", elem_id="tab-citations", id=1):
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")
    with gr.Tab("About", elem_classes="max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://www.oag.go.ug/](https://www.oag.go.ug/welcome)")

    def start_chat(query, history):
        """Append the question to the history, lock the textbox and show the Sources tab (id=1)."""
        history = [tuple(x) for x in history + [(query, None)]]
        return (gr.update(interactive=False), gr.update(selected=1), history)

    def finish_chat():
        """Re-enable and clear the textbox once the answer has been produced."""
        return gr.update(interactive=True, value="")

    # Each chat turn: start_chat -> chat (answer + sources) -> finish_chat.
    (textbox
        .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
        .then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports], [chatbot, sources_textbox, output_query], concurrency_limit=8, api_name="chat_textbox")
        .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
    )
    (examples_hidden
        .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
        .then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports], [chatbot, sources_textbox, output_query], concurrency_limit=8, api_name="chat_examples")
        .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
    )

    def change_sample_questions(key):
        """Show only the example row that belongs to the selected category."""
        index = list(QUESTIONS.keys()).index(key)
        return [gr.update(visible=(row_index == index)) for row_index, _ in enumerate(samples)]

    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
    demo.queue()
demo.launch()