import os

os.environ["GIT_CLONE_PROTECTION_ACTIVE"] = "false"

import io
import shutil
import subprocess
from pathlib import Path

import requests
import openvino as ov
import torch
from transformers import (
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from huggingface_hub import login


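# Fetch the shared llm_config helper and the example PDF documents used for retrieval.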
config_shared_path = Path("../../utils/llm_config.py")
config_dst_path = Path("llm_config.py")
text_example_en_path = Path("text_example_en.pdf")
text_example_cn_path = Path("text_example_cn.pdf")
text_example_en = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039728/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final.pdf"
text_example_cn = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039713/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final_CH.pdf"

if not config_dst_path.exists():
    if config_shared_path.exists():
        try:
            os.symlink(config_shared_path, config_dst_path)
        except Exception:
            shutil.copy(config_shared_path, config_dst_path)
    else:
        r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/llm_config.py")
        with open("llm_config.py", "w", encoding="utf-8") as f:
            f.write(r.text)
elif not os.path.islink(config_dst_path):
    print("LLM config will be updated")
    if config_shared_path.exists():
        shutil.copy(config_shared_path, config_dst_path)
    else:
        r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/llm_config.py")
        with open("llm_config.py", "w", encoding="utf-8") as f:
            f.write(r.text)

# llm_config can only be imported once the helper module exists on disk
from llm_config import (
    SUPPORTED_EMBEDDING_MODELS,
    SUPPORTED_RERANK_MODELS,
    SUPPORTED_LLM_MODELS,
)

if not text_example_en_path.exists():
    r = requests.get(url=text_example_en)
    content = io.BytesIO(r.content)
    with open("text_example_en.pdf", "wb") as f:
        f.write(content.read())

if not text_example_cn_path.exists():
    r = requests.get(url=text_example_cn)
    content = io.BytesIO(r.content)
    with open("text_example_cn.pdf", "wb") as f:
        f.write(content.read())

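# LLM selection: the language and model id pick the matching entry from llm_config,
# plus flags for which weight precisions to prepare.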
model_language = "English"
llm_model_id = "llama-3-8b-instruct"                                                #"llama-3.2-3b-instruct"                          #"llama-3-8b-instruct"
llm_model_configuration = SUPPORTED_LLM_MODELS[model_language][llm_model_id]
print(f"Selected LLM model {llm_model_id}")
prepare_int4_model = True   # Prepare INT4 model
prepare_int8_model = False  # Do not prepare INT8 model
prepare_fp16_model = False  # Do not prepare FP16 model
enable_awq = False
# Get the token from the environment variable
hf_token = os.getenv("HUGGINGFACE_TOKEN")

if hf_token is None:
    raise ValueError(
        "HUGGINGFACE_TOKEN environment variable not set. "
        "Please set it in your environment variables or repository secrets."
    )

# Log in to Hugging Face Hub
login(token=hf_token)
pt_model_id = llm_model_configuration["model_id"]
# pt_model_name = llm_model_id.value.split("-")[0]
fp16_model_dir = Path(llm_model_id) / "FP16"
int8_model_dir = Path(llm_model_id) / "INT8_compressed_weights"
int4_model_dir = Path(llm_model_id) / "INT4_compressed_weights"


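# Weight-conversion helpers: each builds an optimum-cli export command for the selected LLM
# and runs it if the target OpenVINO IR does not already exist.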
def convert_to_fp16():
    if (fp16_model_dir / "openvino_model.xml").exists():
        return
    remote_code = llm_model_configuration.get("remote_code", False)
    export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format fp16".format(pt_model_id)
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + " " + str(fp16_model_dir)
    # Run the optimum-cli export in a subprocess.
    subprocess.run(export_command, shell=True, check=True)


def convert_to_int8():
    if (int8_model_dir / "openvino_model.xml").exists():
        return
    int8_model_dir.mkdir(parents=True, exist_ok=True)
    remote_code = llm_model_configuration.get("remote_code", False)
    export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format int8".format(pt_model_id)
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + " " + str(int8_model_dir)
    # Run the optimum-cli export in a subprocess.
    subprocess.run(export_command, shell=True, check=True)



def convert_to_int4():
    compression_configs = {
        "zephyr-7b-beta": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "mistral-7b": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "minicpm-2b-dpo": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "gemma-2b-it": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "notus-7b-v1": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "neural-chat-7b-v3-1": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "llama-2-chat-7b": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.8,
        },
        "llama-3-8b-instruct": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.8,
        },
        "gemma-7b-it": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.8,
        },
        "chatglm2-6b": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.72,
        },
        "qwen-7b-chat": {"sym": True, "group_size": 128, "ratio": 0.6},
        "red-pajama-3b-chat": {
            "sym": False,
            "group_size": 128,
            "ratio": 0.5,
        },
        "default": {
            "sym": False,
            "group_size": 128,
            "ratio": 0.8,
        },
    }

    model_compression_params = compression_configs.get(llm_model_id, compression_configs["default"])
    if (int4_model_dir / "openvino_model.xml").exists():
        return
    remote_code = llm_model_configuration.get("remote_code", False)
    export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format int4".format(pt_model_id)
    int4_compression_args = " --group-size {} --ratio {}".format(model_compression_params["group_size"], model_compression_params["ratio"])
    if model_compression_params["sym"]:
        int4_compression_args += " --sym"
    if enable_awq:
        int4_compression_args += " --awq --dataset wikitext2 --num-samples 128"
    export_command_base += int4_compression_args
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + " " + str(int4_model_dir)
    # Run the optimum-cli export in a subprocess.
    subprocess.run(export_command, shell=True, check=True)



if prepare_fp16_model:
    convert_to_fp16()
if prepare_int8_model:
    convert_to_int8()
if prepare_int4_model:
    convert_to_int4()
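
# Report the on-disk size of each converted model and the compression ratio relative to FP16.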
fp16_weights = fp16_model_dir / "openvino_model.bin"
int8_weights = int8_model_dir / "openvino_model.bin"
int4_weights = int4_model_dir / "openvino_model.bin"

if fp16_weights.exists():
    print(f"Size of FP16 model is {fp16_weights.stat().st_size / 1024 / 1024:.2f} MB")
for precision, compressed_weights in zip([8, 4], [int8_weights, int4_weights]):
    if compressed_weights.exists():
        print(f"Size of model with INT{precision} compressed weights is {compressed_weights.stat().st_size / 1024 / 1024:.2f} MB")
    if compressed_weights.exists() and fp16_weights.exists():
        print(f"Compression rate for INT{precision} model: {fp16_weights.stat().st_size / compressed_weights.stat().st_size:.3f}")
embedding_model_id = "bge-small-en-v1.5"  # other options: "bge-large-en-v1.5", "bge-m3"
embedding_model_configuration = SUPPORTED_EMBEDDING_MODELS[model_language][embedding_model_id]
print(f"Selected {embedding_model_id} model")
export_command_base = "optimum-cli export openvino --model {} --task feature-extraction".format(embedding_model_configuration["model_id"])
export_command = export_command_base + " " + str(embedding_model_id)
if not Path(embedding_model_id).exists():
    subprocess.run(export_command, shell=True, check=True)
rerank_model_id = "bge-reranker-v2-m3"                               #'bge-reranker-v2-m3', 'bge-reranker-large', 'bge-reranker-base')
rerank_model_configuration = SUPPORTED_RERANK_MODELS[rerank_model_id]
print(f"Selected {rerank_model_id} model")
export_command_base = "optimum-cli export openvino --model {} --task text-classification".format(rerank_model_configuration["model_id"])
export_command = export_command_base + " " + str(rerank_model_id)
if not Path(rerank_model_id).exists():
    subprocess.run(export_command, shell=True, check=True)
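
# Inference devices for the embedding, reranking, and LLM models.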
embedding_device = "CPU"
USING_NPU = embedding_device == "NPU"

npu_embedding_dir = embedding_model_id + "-npu"
npu_embedding_path = Path(npu_embedding_dir) / "openvino_model.xml"
if USING_NPU and not Path(npu_embedding_dir).exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
    )
    with open("notebook_utils.py", "w") as f:
        f.write(r.text)
    import notebook_utils as utils

    shutil.copytree(embedding_model_id, npu_embedding_dir)
    utils.optimize_bge_embedding(Path(embedding_model_id) / "openvino_model.xml", npu_embedding_path)
rerank_device = "CPU"
llm_device = "CPU"
from langchain_community.embeddings import OpenVINOBgeEmbeddings

embedding_model_name = npu_embedding_dir if USING_NPU else embedding_model_id
batch_size = 1 if USING_NPU else 4
embedding_model_kwargs = {"device": embedding_device, "compile": False}
encode_kwargs = {
    "mean_pooling": embedding_model_configuration["mean_pooling"],
    "normalize_embeddings": embedding_model_configuration["normalize_embeddings"],
    "batch_size": batch_size,
}

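# Build the LangChain embedding wrapper; compile=False defers compilation so the model can be reshaped for NPU first.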
embedding = OpenVINOBgeEmbeddings(
    model_name_or_path=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=encode_kwargs,
)
if USING_NPU:
    embedding.ov_model.reshape(1, 512)
embedding.ov_model.compile()

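# Quick smoke test of the embedding model.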
text = "This is a test document."
embedding_result = embedding.embed_query(text)
print(embedding_result[:3])
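
# Reranker: re-orders retrieved chunks by relevance to the query before they reach the LLM.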
from langchain_community.document_compressors.openvino_rerank import OpenVINOReranker

rerank_model_name = rerank_model_id
rerank_model_kwargs = {"device": rerank_device}
rerank_top_n = 2

reranker = OpenVINOReranker(
    model_name_or_path=rerank_model_name,
    model_kwargs=rerank_model_kwargs,
    top_n=rerank_top_n,
)
model_to_run = "INT4"
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

if model_to_run == "INT4":
    model_dir = int4_model_dir
elif model_to_run == "INT8":
    model_dir = int8_model_dir
else:
    model_dir = fp16_model_dir
print(f"Loading model from {model_dir}")

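# LATENCY hint with a single stream favors per-request response time; an empty CACHE_DIR disables model caching.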
ov_config = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}

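# Build the LangChain LLM wrapper on top of an OpenVINO text-generation pipeline.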
print("starting setting llm model")
llm = HuggingFacePipeline.from_model_id(
    model_id="meta-llama/Meta-Llama-3-8B",
    task="text-generation",
    backend="openvino",
    model_kwargs={
        "device": llm_device,
        "ov_config": ov_config,
        "trust_remote_code": True,
    },
    pipeline_kwargs={"max_new_tokens": 2},
)

print(llm.invoke("2 + 2 ="))
# Alternative approach (kept for reference): export and run the model directly with optimum-intel.
# from optimum.intel.openvino import OVModelForCausalLM
# from transformers import pipeline


# model_id = "meta-llama/Meta-Llama-3-8B"
# ov_config = {"PERFORMANCE_HINT": "LATENCY"}  # example only; check your actual ov_config

# # Export the model with OpenVINO
# model = OVModelForCausalLM.from_pretrained(
#     model_id,
#     export=True,  # convert the model to OpenVINO format
#     use_cache=False,
#     ov_config=ov_config,
#     trust_remote_code=True  # allow custom modeling code from the Hub
# )

# # Save the OpenVINO model
# model.save_pretrained("./openvino_llama_model")

# # Step 2: load the saved OpenVINO model and set up the inference task
# llm_device = "CPU"  # make sure the device matches your environment
# llm = pipeline(
#     task="text-generation",
#     model=OVModelForCausalLM.from_pretrained("./openvino_llama_model"),
#     device=llm_device,
#     max_new_tokens=2  # maximum number of new tokens to generate
# )

# # Step 3: run inference
# output = llm("2 + 2 =")
# print(output)

# print("test: 2 + 2:")
# print(llm.invoke("2 + 2 ="))
import re
from typing import List
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    MarkdownTextSplitter,
)
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list


TEXT_SPLITERS = {
    "Character": CharacterTextSplitter,
    "RecursiveCharacter": RecursiveCharacterTextSplitter,
    "Markdown": MarkdownTextSplitter,
    "Chinese": ChineseTextSplitter,
}


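# Map file extensions to the LangChain loader class (and loader kwargs) used to read them.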
LOADERS = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PyPDFLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
}

chinese_examples = [
    ["英特尔®酷睿™ Ultra处理器可以降低多少功耗?"],
    ["相比英特尔之前的移动处理器产品,英特尔®酷睿™ Ultra处理器的AI推理性能提升了多少?"],
    ["英特尔博锐® Enterprise系统提供哪些功能?"],
]

english_examples = [
    ["How much power consumption can Intel® Core™ Ultra Processors help save?"],
    ["Compared to Intel’s previous mobile processor, what is the advantage of Intel® Core™ Ultra Processors for Artificial Intelligence?"],
    ["What can Intel vPro® Enterprise systems offer?"],
]

if model_language == "English":
    # text_example_path = "text_example_en.pdf"
    text_example_path = ['Supervisors-Guide-Accurate-Timekeeping_AH edits.docx','Salary-vs-Hourly-Guide_AH edits.docx','Employee-Guide-Accurate-Timekeeping_AH edits.docx','Eller Overtime Guidelines.docx','Eller FLSA information 9.2024_AH edits.docx','Accurate Timekeeping Supervisors 12.2.20_AH edits.docx']
else:
    text_example_path = "text_example_cn.pdf"

examples = chinese_examples if (model_language == "Chinese") else english_examples
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
from langchain.retrievers import ContextualCompressionRetriever
from threading import Thread
import gradio as gr

stop_tokens = llm_model_configuration.get("stop_tokens")
rag_prompt_template = llm_model_configuration["rag_prompt_template"]


class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


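# Convert string stop tokens to token ids and wrap them in a stopping criterion for generation.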
if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = llm.pipeline.tokenizer.convert_tokens_to_ids(stop_tokens)

    stop_tokens = [StopOnTokens(stop_tokens)]


def load_single_document(file_path: str) -> List[Document]:
    """
    helper for loading a single document

    Params:
      file_path: document path
    Returns:
      documents loaded

    """
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in LOADERS:
        loader_class, loader_args = LOADERS[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()

    raise ValueError(f"File does not exist '{ext}'")


def default_partial_text_processor(partial_text: str, new_text: str):
    """
    helper for updating partially generated answer, used by default

    Params:
      partial_text: text buffer for storing previously generated text
      new_text: text update for the current step
    Returns:
      updated text string

    """
    partial_text += new_text
    return partial_text


text_processor = llm_model_configuration.get("partial_text_processor", default_partial_text_processor)


def create_vectordb(
    docs, spliter_name, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold, progress=gr.Progress()
):
    """
    Initialize a vector database

    Params:
      docs: original documents provided by user
      spliter_name: splitter method
      chunk_size:  size of a single sentence chunk
      chunk_overlap: overlap size between 2 chunks
      vector_search_top_k: Vector search top k
      vector_rerank_top_n: Search rerank top n
      run_rerank: whether run reranker
      search_method: top k search method
      score_threshold: score threshold when selecting 'similarity_score_threshold' method

    """
    global db
    global retriever
    global combine_docs_chain
    global rag_chain

    if vector_rerank_top_n > vector_search_top_k:
        gr.Warning("Search top k must >= Rerank top n")

    documents = []
    for doc in docs:
        if type(doc) is not str:
            doc = doc.name
        documents.extend(load_single_document(doc))

    text_splitter = TEXT_SPLITERS[spliter_name](chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    texts = text_splitter.split_documents(documents)
    db = FAISS.from_documents(texts, embedding)
    if search_method == "similarity_score_threshold":
        search_kwargs = {"k": vector_search_top_k, "score_threshold": score_threshold}
    else:
        search_kwargs = {"k": vector_search_top_k}
    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
    if run_rerank:
        reranker.top_n = vector_rerank_top_n
        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
    prompt = PromptTemplate.from_template(rag_prompt_template)
    combine_docs_chain = create_stuff_documents_chain(llm, prompt)

    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)

    return "Vector database is Ready"


def update_retriever(vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold):
    """
    Update retriever

    Params:
      vector_search_top_k: Vector search top k
      vector_rerank_top_n: Search rerank top n
      run_rerank: whether run reranker
      search_method: top k search method
      score_threshold: score threshold when selecting 'similarity_score_threshold' method

    """
    global db
    global retriever
    global combine_docs_chain
    global rag_chain

    if vector_rerank_top_n > vector_search_top_k:
        gr.Warning("Search top k must >= Rerank top n")

    if search_method == "similarity_score_threshold":
        search_kwargs = {"k": vector_search_top_k, "score_threshold": score_threshold}
    else:
        search_kwargs = {"k": vector_search_top_k}
    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
    if run_rerank:
        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
        reranker.top_n = vector_rerank_top_n
    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)

    return "Vector database is Ready"


def user(message, history):
    """
    callback function for updating user messages in interface on submit button click

    Params:
      message: current message
      history: conversation history
    Returns:
      empty string (to clear the message box) and the updated conversation history
    """
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]


def bot(history, temperature, top_p, top_k, repetition_penalty, hide_full_prompt, do_rag):
    """
    callback function for running chatbot on submit button click

    Params:
      history: conversation history
      temperature: parameter to control the level of creativity in AI-generated text.
                   By adjusting the `temperature`, you can influence the AI model's probability distribution, making the text more focused or diverse.
      top_p: parameter to control the range of tokens considered by the AI model based on their cumulative probability.
      top_k: parameter to control the number of highest-probability tokens considered by the AI model.
      repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.
      hide_full_prompt: whether to hide the full prompt (including retrieved context) in the streamed output.
      do_rag: whether to do RAG when generating texts.

    """
    streamer = TextIteratorStreamer(
        llm.pipeline.tokenizer,
        timeout=60.0,
        skip_prompt=hide_full_prompt,
        skip_special_tokens=True,
    )
    llm.pipeline._forward_params = dict(
        max_new_tokens=512,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    if stop_tokens is not None:
        llm.pipeline._forward_params["stopping_criteria"] = StoppingCriteriaList(stop_tokens)

    if do_rag:
        t1 = Thread(target=rag_chain.invoke, args=({"input": history[-1][0]},))
    else:
        input_text = rag_prompt_template.format(input=history[-1][0], context="")
        t1 = Thread(target=llm.invoke, args=(input_text,))
    t1.start()

    # Initialize an empty string to store the generated text
    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1][1] = partial_text
        yield history


def request_cancel():
    llm.pipeline.model.request.cancel()


def clear_files():
    return "Vector Store is Not ready"


# initialize the vector store with example document
create_vectordb(
    text_example_path,
    "RecursiveCharacter",
    chunk_size=400,
    chunk_overlap=50,
    vector_search_top_k=10,
    vector_rerank_top_n=2,
    run_rerank=True,
    search_method="similarity_score_threshold",
    score_threshold=0.5,
)
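
# Gradio UI: document upload, vector store configuration, generation settings, chat window, and retriever controls.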
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown("""<h1><center>QA over Document</center></h1>""")
    gr.Markdown(f"""<center>Powered by OpenVINO and {llm_model_id} </center>""")
    with gr.Row():
        with gr.Column(scale=1):
            docs = gr.File(
                label="Step 1: Load text files",
                value=text_example_path,
                file_count="multiple",
                file_types=[
                    ".csv",
                    ".doc",
                    ".docx",
                    ".enex",
                    ".epub",
                    ".html",
                    ".md",
                    ".odt",
                    ".pdf",
                    ".ppt",
                    ".pptx",
                    ".txt",
                ],
            )
            load_docs = gr.Button("Step 2: Build Vector Store", variant="primary")
            db_argument = gr.Accordion("Vector Store Configuration", open=False)
            with db_argument:
                spliter = gr.Dropdown(
                    ["Character", "RecursiveCharacter", "Markdown", "Chinese"],
                    value="RecursiveCharacter",
                    label="Text Splitter",
                    info="Method used to split the documents",
                    multiselect=False,
                )

                chunk_size = gr.Slider(
                    label="Chunk size",
                    value=400,
                    minimum=50,
                    maximum=2000,
                    step=50,
                    interactive=True,
                    info="Size of sentence chunk",
                )

                chunk_overlap = gr.Slider(
                    label="Chunk overlap",
                    value=50,
                    minimum=0,
                    maximum=400,
                    step=10,
                    interactive=True,
                    info=("Overlap between 2 chunks"),
                )

            langchain_status = gr.Textbox(
                label="Vector Store Status",
                value="Vector Store is Ready",
                interactive=False,
            )
            do_rag = gr.Checkbox(
                value=True,
                label="RAG is ON",
                interactive=True,
                info="Whether to do RAG for generation",
            )
            with gr.Accordion("Generation Configuration", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.1,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=1.0,
                                minimum=0.0,
                                maximum=1,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=50,
                                minimum=0.0,
                                maximum=200,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            repetition_penalty = gr.Slider(
                                label="Repetition Penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.1,
                                interactive=True,
                                info="Penalize repetition — 1.0 to disable.",
                            )
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=800,
                label="Step 3: Input Query",
            )
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        msg = gr.Textbox(
                            label="QA Message Box",
                            placeholder="Chat Message Box",
                            show_label=False,
                            container=False,
                        )
                with gr.Column():
                    with gr.Row():
                        submit = gr.Button("Submit", variant="primary")
                        stop = gr.Button("Stop")
                        clear = gr.Button("Clear")
            gr.Examples(examples, inputs=msg, label="Click on any example and press the 'Submit' button")
            retriever_argument = gr.Accordion("Retriever Configuration", open=True)
            with retriever_argument:
                with gr.Row():
                    with gr.Row():
                        do_rerank = gr.Checkbox(
                            value=True,
                            label="Rerank searching result",
                            interactive=True,
                        )
                        hide_context = gr.Checkbox(
                            value=True,
                            label="Hide searching result in prompt",
                            interactive=True,
                        )
                    with gr.Row():
                        search_method = gr.Dropdown(
                            ["similarity_score_threshold", "similarity", "mmr"],
                            value="similarity_score_threshold",
                            label="Searching Method",
                            info="Method used to search vector store",
                            multiselect=False,
                            interactive=True,
                        )
                    with gr.Row():
                        score_threshold = gr.Slider(
                            0.01,
                            0.99,
                            value=0.5,
                            step=0.01,
                            label="Similarity Threshold",
                            info="Only working for 'similarity score threshold' method",
                            interactive=True,
                        )
                    with gr.Row():
                        vector_rerank_top_n = gr.Slider(
                            1,
                            10,
                            value=2,
                            step=1,
                            label="Rerank top n",
                            info="Number of rerank results",
                            interactive=True,
                        )
                    with gr.Row():
                        vector_search_top_k = gr.Slider(
                            1,
                            50,
                            value=10,
                            step=1,
                            label="Search top k",
                            info="Search top k must >= Rerank top n",
                            interactive=True,
                        )
    docs.clear(clear_files, outputs=[langchain_status], queue=False)
    load_docs.click(
        create_vectordb,
        inputs=[docs, spliter, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
        queue=False,
    )
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
        chatbot,
        queue=True,
    )
    submit_click_event = submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
        chatbot,
        queue=True,
    )
    stop.click(
        fn=request_cancel,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    vector_search_top_k.release(
        update_retriever,
        [vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    vector_rerank_top_n.release(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    do_rerank.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    search_method.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    score_threshold.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )


demo.queue()
# If you are launching remotely, specify server_name and server_port:
# demo.launch(server_port=8082)
# If you have any issue launching on your platform, you can pass share=True to the launch method;
# it creates a publicly shareable link for the interface. Read more in the docs: https://gradio.app/docs/
demo.launch()