import stat
import gradio as gr
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core import StorageContext
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
import zipfile
import requests
import torch
from llama_index.core import Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
import sys
import logging
import os


# ---- Retrieval / model configuration -------------------------------------
enable_rerank = True
# One of: sentence_window, naive  (recursive_retrieval is not implemented)
retrieval_strategy = "sentence_window"
base_embedding_source = "hf"  # local,openai,hf
# Alternatives: intfloat/multilingual-e5-small, local:BAAI/bge-small-en-v1.5,
# text-embedding-3-small, nvidia/NV-Embed-v2, Alibaba-NLP/gte-large-en-v1.5
base_embedding_model = "Alibaba-NLP/gte-large-en-v1.5"
# Alternatives: meta-llama/Llama-3.1-8B, meta-llama/Llama-3.2-3B-Instruct,
# meta-llama/Llama-2-7b-chat-hf, google/gemma-2-9b,
# CohereForAI/c4ai-command-r-plus, CohereForAI/aya-23-8B, AdaptLLM/finance-chat
base_llm_model = "mistralai/Mistral-7B-Instruct-v0.3"
base_llm_source = "hf"  # cohere,hf,anthropic
# Number of nodes retrieved per query (before any reranking).
base_similarity_top_k = 20


# ---- ChromaDB --------------------------------------------------------------
env_extension = "_large"  # _large _dev_window _large_window
db_collection = f"gte{env_extension}"  # intfloat gte
# NOTE(review): read_db and active_chroma are not referenced in the visible
# code; kept for compatibility.
read_db = True
active_chroma = True
root_path = "."
chroma_db_path = f"{root_path}/chroma_db"
processed_files_log = f"{root_path}/processed_files{env_extension}.json"


# ---- Validate configuration ------------------------------------------------
_SUPPORTED_RETRIEVAL_STRATEGIES = ("sentence_window", "naive")
if retrieval_strategy not in _SUPPORTED_RETRIEVAL_STRATEGIES:
    # ValueError subclasses Exception, so existing handlers still catch it.
    raise ValueError(
        f"{retrieval_strategy} retrieval_strategy is not supported "
        f"(expected one of {_SUPPORTED_RETRIEVAL_STRATEGIES})")


# NOTE(review): presumably a placeholder for llama-index code paths that look
# for an OpenAI key; nothing in this file uses OpenAI. setdefault keeps a real
# key from the environment instead of clobbering it.
os.environ.setdefault("OPENAI_API_KEY", 'sk-xxxxxxxxxx')
# Used as a Bearer token when downloading the ChromaDB archive (may be None).
hf_api_key = os.getenv("HF_API_KEY")

# basicConfig attaches a stdout StreamHandler to the root logger (when none is
# configured yet); adding a second stdout handler printed every record twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Set the CUDA allocator option before any torch.cuda call so it is in place
# before the caching allocator is initialised.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.cuda.empty_cache()

# Embedding model used both for the stored vectors and for encoding queries.
print(f"loading embedding ..{base_embedding_model}")
if base_embedding_source == 'hf':
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    # NOTE(review): trust_remote_code is presumably needed because models such
    # as Alibaba-NLP/gte-large-en-v1.5 ship custom modelling code.
    Settings.embed_model = HuggingFaceEmbedding(
        model_name=base_embedding_model, trust_remote_code=True)
else:
    # The problem is the configured *source*, not the model name.
    raise ValueError(
        f"unsupported base_embedding_source {base_embedding_source!r}; "
        "only 'hf' is implemented")

# LLM setup: a local Hugging Face model served through llama-index.
if base_llm_source == 'hf':
    from llama_index.core import PromptTemplate

    # Template HuggingFaceLLM wraps around each prompt (inserted as
    # {query_str}); format adapted from https://huggingface.co/Writer/camel-5b-hf
    # plus an instruction to refuse when question and context name different
    # stock symbols.
    query_wrapper_prompt = PromptTemplate(
        "Below is an instruction that describes a task. "
        # Keep the trailing ". " - adjacent literals are joined with no
        # separator, so without it this sentence runs into the next one.
        "you need to make sure that user's question and retrived context mention the same stock symbol if not please give no answer to user. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query_str}\n\n### Response:"
    )

    llm = HuggingFaceLLM(
        context_window=2048,
        max_new_tokens=512,
        # Greedy decoding: transformers ignores temperature when
        # do_sample=False.
        generate_kwargs={"temperature": 0.1, "do_sample": False},
        query_wrapper_prompt=query_wrapper_prompt,
        tokenizer_name=base_llm_model,
        model_name=base_llm_model,
        device_map="auto",
        tokenizer_kwargs={"max_length": 2048},
        # Load weights in half precision to reduce memory usage.
        model_kwargs={"torch_dtype": torch.float16},
    )

    Settings.chunk_size = 512
    Settings.llm = llm
else:
    # Fail fast instead of silently falling back to llama-index's default LLM.
    raise ValueError(
        f"unsupported base_llm_source {base_llm_source!r}; "
        "only 'hf' is implemented")

# ---- Load documents, build the VectorStoreIndex ----


def _remove_archive_roots(zip_ref, destination, keep=()):
    """Delete entries in *destination* that *zip_ref* is about to re-create.

    Only the top-level names of the archive's members are considered, so
    unrelated files sharing *destination* (for example this app's own source
    when *destination* is ``"./"``) are left untouched. Empty, ``.``/``..``,
    drive-qualified or backslash-containing names, and names in *keep*, are
    skipped.
    """
    import shutil

    roots = {name.split("/", 1)[0] for name in zip_ref.namelist()}
    for root in roots:
        if (root in ("", ".", "..") or root in keep
                or "\\" in root or ":" in root):
            continue
        target = os.path.join(destination, root)
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        elif os.path.lexists(target):
            os.remove(target)


def download_and_extract_chroma_db(url, destination):
    """Download the ChromaDB archive (unless cached) and extract it.

    The archive is stored as ``<destination>/chroma_db.zip`` and reused on
    later runs. Before extracting, existing top-level entries that the archive
    contains are removed so the result is a fresh copy of its contents;
    nothing else in *destination* is deleted.

    Args:
        url: Direct download URL of the zip archive.
        destination: Directory to store and extract the archive into; created
            if missing.

    Raises:
        requests.HTTPError: If the download returns an error status.
        zipfile.BadZipFile: If the stored archive is not a valid zip file.
    """
    os.makedirs(destination, exist_ok=True)

    zip_name = "chroma_db.zip"
    db_zip_path = os.path.join(destination, zip_name)
    if os.path.exists(db_zip_path):
        print("Zip file already exists, skipping download.")
    else:
        print("Downloading ChromaDB from Hugging Face Datasets...")
        # Authenticate only when a token is configured; otherwise the literal
        # string "Bearer None" would be sent as a credential.
        headers = {"Authorization": f"Bearer {hf_api_key}"} if hf_api_key else {}
        partial_path = db_zip_path + ".part"
        with requests.get(url, headers=headers, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Rename only after a complete download so an interrupted run never
        # leaves a truncated archive that later runs would reuse.
        os.replace(partial_path, db_zip_path)
        print("Download completed.")

    print("Extracting ChromaDB...")
    with zipfile.ZipFile(db_zip_path, 'r') as zip_ref:
        _remove_archive_roots(zip_ref, destination, keep={zip_name})
        zip_ref.extractall(destination)
    print("Extraction completed. Zip file retained.")


# URL of the prebuilt ChromaDB archive hosted as a Hugging Face dataset.
chroma_db_url = "https://huggingface.co/datasets/iamboolean/set50-db/resolve/main/chroma_db.zip"

# Directory the archive is extracted into.
# NOTE(review): chroma_db_path ("./chroma_db") is presumably a folder inside
# the archive, so this should stay its parent directory - confirm before
# changing.
chroma_db_path_extract = "./"

# Download (if not cached) and extract the ChromaDB files.
download_and_extract_chroma_db(chroma_db_url, chroma_db_path_extract)

# Persistent ChromaDB client over the extracted files.
db = chromadb.PersistentClient(path=chroma_db_path)
print(f"db path:{chroma_db_path}")
# get_or_create: a wrong path or collection name yields a new, empty
# collection rather than an error - see the emptiness check below.
chroma_collection = db.get_or_create_collection(db_collection)
print(f"db collection:{db_collection}")


# Wrap the collection as a llama-index vector store.
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
# NOTE(review): storage_context is not used by the index below.
storage_context = StorageContext.from_defaults(vector_store=vector_store)

document_count = chroma_collection.count()
print(f"Total documents in the collection: {document_count}")
if document_count == 0:
    logging.warning(
        "ChromaDB collection %r at %s is empty; queries will retrieve no context.",
        db_collection, chroma_db_path)

# Index over the existing vectors; queries are embedded with
# Settings.embed_model.
index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
)

# ---- Query Index ----


# Post-processors are applied in list order to the retrieved nodes.
node_postprocessors = []
# Optional: node_postprocessors.append(SimilarityPostprocessor(similarity_cutoff=0.6))

if retrieval_strategy == 'sentence_window':
    # Replace each matched sentence with the text stored under the "window"
    # metadata key (the node_parser's default target key).
    node_postprocessors.append(
        MetadataReplacementPostProcessor(target_metadata_key="window"))

# Cross-encoder reranker keeping the top 10 of the retrieved nodes. It is
# built only when enabled so the model is not loaded needlessly; `rerank`
# stays None otherwise.
rerank = None
if enable_rerank:
    rerank = SentenceTransformerRerank(
        model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=10
    )
    node_postprocessors.append(rerank)


query_engine = index.as_query_engine(
    similarity_top_k=base_similarity_top_k,
    node_postprocessors=node_postprocessors,
)


def metadata_formatter(metadata):
    """Render a retrieved node's metadata as a one-line citation.

    ``metadata['file_name']`` is expected to look like ``"<SYMBOL>-<YEAR>.<ext>"``
    (e.g. ``"PTTOR-2022.pdf"``), and ``metadata['page_label']`` to hold the
    page. Missing pieces are rendered as ``"unknown"`` instead of raising, so
    one oddly named source document cannot break the whole answer.

    Args:
        metadata: Mapping of node metadata.

    Returns:
        ``"Company File: <symbol>, Year: <year>, Page Number: <page>"``.
    """
    file_name = metadata.get('file_name', '')
    # Split on '-' after dropping the extension: for "SYMBOL-YEAR.ext" names
    # this matches splitting the raw name.
    parts = os.path.splitext(file_name)[0].split('-')
    company_symbol = parts[0] or 'unknown'
    year = parts[1].split('.')[0] if len(parts) > 1 else 'unknown'
    page_number = metadata.get('page_label', 'unknown')

    return f"Company File: {company_symbol}, Year: {year}, Page Number: {page_number}"


def query_journal(question):
    """Answer *question* with the RAG query engine.

    Args:
        question: The user's question text (may be empty or None).

    Returns:
        A ``(retrieved_context, generated_answer)`` tuple of strings: one
        ``"Metadata: ..."`` citation line per source node, and the LLM's
        answer. For a blank question the engine is not called and
        ``("", "Please enter a question.")`` is returned.
    """
    # Avoid spending an LLM call (and retrieval) on an empty submission.
    if not question or not question.strip():
        return "", "Please enter a question."

    response = query_engine.query(question)

    # One citation line per matched source node.
    retrieved_context = "\n".join(
        f"Metadata: {metadata_formatter(node.node.metadata) if node.node.metadata else 'None'}"
        for node in response.source_nodes
    )

    generated_answer = str(response)

    return retrieved_context, generated_answer


# Example questions advertised in the UI. The "Use Example Question" button
# cycles through them, starting with the first.
EXAMPLE_QUESTIONS = (
    "What is the revenue of PTTOR in 2022?",
    "what is effect of COVID-19 on BDMS show me in Timeline format from 2019 to 2023?",
    "How does CPALL plan for electric vehicles?",
)


def use_example_question():
    """Return the first example question."""
    return EXAMPLE_QUESTIONS[0]


def next_example_question(index):
    """Return the example question at *index* and the index of the next one.

    *index* wraps around, so repeated clicks cycle through every example.
    """
    index %= len(EXAMPLE_QUESTIONS)
    return EXAMPLE_QUESTIONS[index], (index + 1) % len(EXAMPLE_QUESTIONS)


# Define the Gradio interface
with gr.Blocks() as app:
    # Title
    gr.Markdown(
        """
        <div style="text-align: center;">
            <h1>SET50RAG: Retrieval-Augmented Generation for Thai Public Companies Question Answering</h1>
        </div>
        """
    )

    # Description
    gr.Markdown(
        """
        The **SET50RAG** tool provides an interactive way to analyze and extract insights from **243 annual reports** of Thai public companies spanning **5 years**.
        By leveraging advanced **Retrieval-Augmented Generation**, including **GTE-Large embedding models**, **Sentence Window with Reranking**, and powerful **Large Language Models (LLMs)** like **Mistral-7B**, the system efficiently retrieves and answers complex financial queries.
        This scalable and cost-effective solution reduces reliance on parametric knowledge, ensuring contextually accurate and relevant responses.
        """
    )

    # How to Use Section
    gr.Markdown(
        """
        ### How to Use
        1. Type your question in the box or select an example question below.
        2. Click **Submit** to retrieve the context and get an AI-generated answer.
        3. Review the retrieved context and the generated answer to gain insights.
        ---
        """
    )

    # Example Questions Section - rendered from EXAMPLE_QUESTIONS so the list
    # shown and the example button always agree.
    gr.Markdown(
        "### Example Questions\n"
        + "\n".join(f"- {example}" for example in EXAMPLE_QUESTIONS)
    )

    # Interactive Section (RAG Box)
    with gr.Row():
        with gr.Column():
            user_question = gr.Textbox(
                label="Ask a Question",
                placeholder="Type your question here, e.g., 'What is the revenue of PTTOR in 2022?'",
            )
            example_question_button = gr.Button("Use Example Question")
        with gr.Column():
            generated_answer = gr.Textbox(
                label="Generated Answer",
                placeholder="The AI-generated answer will appear here.",
                interactive=False,
            )
            retrieved_context = gr.Textbox(
                label="Retrieved Context",
                placeholder="Relevant context will appear here.",
                interactive=False,
            )

    # Button for user interaction
    submit_button = gr.Button("Submit")

    # Per-session position in EXAMPLE_QUESTIONS for the example button.
    example_index = gr.State(0)
    example_question_button.click(
        next_example_question,
        inputs=[example_index],
        outputs=[user_question, example_index],
    )

    # Clicking Submit or pressing Enter in the question box runs the query.
    # Outputs are ordered to match query_journal's (context, answer) return.
    submit_button.click(
        query_journal, inputs=[user_question], outputs=[
            retrieved_context, generated_answer]
    )
    user_question.submit(
        query_journal, inputs=[user_question], outputs=[
            retrieved_context, generated_answer]
    )

    # Footer
    gr.Markdown(
        """
        ---
        ### Limitations and Bias:
        - Optimized for Thai financial reports from SET50 companies. Results may vary for other domains.
        - Retrieval and accuracy depend on data quality and embedding models.
        """
    )

# Launch the app, listening on all interfaces.
app.launch(server_name="0.0.0.0")  # , server_port=7860