import os
import gradio as gr
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.retrievers import MultiQueryRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.llms import LlamaCpp, HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
from huggingface_hub import hf_hub_download, login
# Authenticate with the Hugging Face Hub; the token is read from the `hf_token` environment variable.
login(token=os.environ['hf_token'])

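# Prompt for the question-generator step: fold the chat history and the follow-up
# question into a single standalone question that can be used for retrieval.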
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a 
standalone question without changing the content in given question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
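# System prompt for the answer step; the retrieved context and the user question
# are appended to it below to form the final QA prompt.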
system_prompt = """You are a helpful assistant, you will use the provided context to answer user questions.
Read the given context before answering questions and think step by step. If you can not answer a user question based on the provided context, inform the user.
Do not use any other information for answering the user. Provide a detailed answer to the question."""

def load_llmware_model():
    """Load a small hosted model via the HF Inference API; it is used below for query rewriting and question condensing."""
    return HuggingFaceHub(
        repo_id="",  # TODO: set the Hub repo id of the small helper model
        verbose=True,
        model_kwargs={
            'temperature': 0.03,
            'n_batch': 128,  # passed straight to the Inference API; not every backend accepts this parameter
        },
    )
def load_quantized_model(model_id=None):
    """Download a GGUF checkpoint from the Hub and load it with llama.cpp.

    Note: `model_id` is currently unused; the Zephyr-7B-beta GGUF is hard-coded,
    so the model radio in the UI does not change which model gets loaded.
    """
    MODEL_ID, MODEL_BASENAME = "TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q5_K_S.gguf"
    try:
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename=MODEL_BASENAME,
            resume_download=True,
            cache_dir="models",
        )
        kwargs = {
            'model_path': model_path,
            'n_ctx': 10000,       # context window size passed to llama.cpp
            'max_tokens': 10000,  # maximum number of tokens to generate
            'n_batch': 512,
            # 'n_gpu_layers': 6,  # uncomment to offload layers to a GPU build of llama.cpp
        }
        return LlamaCpp(**kwargs)
    except TypeError:
        print("Supported model architectures: Llama, Mistral")
        return None

def upload_files(files):
    """Return the local paths of uploaded files (helper for a gr.File component; not wired into the UI below)."""
    file_paths = [file.name for file in files]
    return file_paths

with gr.Blocks() as demo:
    gr.Markdown(
    """
    <h2> <center> PrivateGPT </center> </h2>
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                model_id = gr.Radio(["Zephyr-7b-Beta", "Llama-2-7b-chat"], value="Llama-2-7b-chat", label="LLM Model")
            with gr.Row():    
                mode = gr.Radio(['OITF Manuals', 'Operations Data'], value='OITF Manuals', label="Data")
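        # Retrieval stack: BGE small embeddings over a persisted Chroma index.
        # The index is expected to already exist under ./db; nothing in this script ingests documents.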
        persist_directory = "db"
        embeddings = HuggingFaceBgeEmbeddings(
            model_name = "BAAI/bge-small-en-v1.5",
            model_kwargs={"device": "cpu"},
            encode_kwargs = {'normalize_embeddings':True},
            cache_folder="models",
        )
        db2 = Chroma(persist_directory = persist_directory,embedding_function = embeddings)
        # llm = load_quantized_model(model_id=model_id) #type:ignore
        # ---------------------------------------------------------------------------------------------------
        llm = load_quantized_model()
        llm_sm = load_llmware_model()
        # ---------------------------------------------------------------------------------------------------
        condense_question_prompt_template = PromptTemplate.from_template(_template)
        prompt_template = system_prompt + """
            {context}
            Question: {question}
            Helpful Answer:"""
        qa_prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
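        # Window memory: k=1 keeps only the most recent exchange as `chat_history`.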
        memory = ConversationBufferWindowMemory(memory_key='chat_history', k=1, return_messages=True)
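        # The hosted helper model rewrites each question into several variants and the
        # union of the Chroma hits (k=5 per variant) is returned as context.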
        retriever_from_llm = MultiQueryRetriever.from_llm(
                retriever=db2.as_retriever(search_kwargs={'k':5}),
                llm = llm_sm,
        )
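        # Conversational RAG chain: the helper model condenses the follow-up question,
        # the multi-query retriever fetches context, and the local quantized model
        # answers using the "stuff" QA prompt.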
        qa2 = ConversationalRetrievalChain(
            retriever=retriever_from_llm,
            # chat_history is supplied by the outer chain, so the question generator needs no memory of its own
            question_generator=LLMChain(llm=llm_sm, prompt=condense_question_prompt_template, verbose=True),  # type: ignore
            combine_docs_chain=load_qa_chain(llm=llm, chain_type="stuff", prompt=qa_prompt, verbose=True), #type:ignore
            memory=memory,
            verbose=True,
        )
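        # Append the user's message to the chat history and clear the input box.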
        def add_text(history, text):
            history = history + [(text, None)]
            return history, ""

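        # Run the chain on the latest user message and write the answer back into the history.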
        def bot(history):
            res = qa2.invoke(
                {
                    'question': history[-1][0],
                    'chat_history': history[:-1]
                }
            )
            history[-1][1] = res['answer']
            # torch.cuda.empty_cache()
            return history
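        # Chat panel: message display, text input, and Submit / Clear buttons.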
        with gr.Column(scale=9): # type: ignore
            with gr.Row():
                chatbot = gr.Chatbot([], elem_id="chatbot", label="Chat", height=500, show_label=True, avatar_images=["user.jpeg", "Bot.jpg"])
            with gr.Row():
                with gr.Column(scale=8): # type: ignore
                    txt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press enter",
                        container=False,
                    )
                with gr.Column(scale=1): # type: ignore
                    submit_btn = gr.Button(
                        'Submit',
                        variant='primary'
                    )
                with gr.Column(scale=1): # type: ignore
                    clear_btn = gr.Button(
                        'Clear',
                        variant="stop"
                    )
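            # Pressing Enter or clicking Submit appends the message and then invokes the chain; Clear resets the chat.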
            txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
                bot, chatbot, chatbot
            )
            submit_btn.click(add_text, [chatbot, txt], [chatbot, txt]).then(
                bot, chatbot, chatbot
            )
            def clear_chat():
                memory.clear()  # keep the chain's window memory in sync with the cleared chat
                return None

            clear_btn.click(clear_chat, None, chatbot, queue=False)

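# Enable Gradio's request queue so long llama.cpp generations are handled one at a time rather than overlapping.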
if __name__ == "__main__":
    demo.queue()
    demo.launch(max_threads=8, debug=True, show_error=True)