import gradio as gr
import spaces
import subprocess
import os
import shutil
import string
import random
import glob
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

# Runtime configuration, each value overridable through an environment variable.
# MODEL: sentence-transformers model id used to embed both documents and queries.
model_name = os.getenv("MODEL", "Snowflake/snowflake-arctic-embed-m")
# CHUNK_SIZE: length, in characters, of each slice produced by chunk_to_length.
chunk_size = int(os.getenv("CHUNK_SIZE", "128"))
# DEFAULT_MAX_CHARACTERS: initial value of the "Max output characters" input.
default_max_characters = int(os.getenv("DEFAULT_MAX_CHARACTERS", "258"))

# Loaded once at import time and shared by the functions below.
model = SentenceTransformer(model_name)

@spaces.GPU
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    """Score every chunk against every query with the loaded embedding model.

    Args:
        queries: Query strings; encoded with the model's ``"query"`` prompt.
        chunks: Document chunk strings; encoded without a prompt.

    Returns:
        A mapping ``{query: [(chunk_idx, score), ...]}`` with one entry per
        chunk, in the original chunk order (not sorted by score). Scores are
        plain Python floats so the result is JSON-serialisable. If a query
        appears twice, the later entry wins.
    """
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)

    # Row i holds query i's dot-product similarity to every chunk
    # (assumes encode() returns one embedding row per input string).
    scores = query_embeddings @ document_embeddings.T
    return {
        query: [(chunk_idx, float(score)) for chunk_idx, score in enumerate(query_scores)]
        for query, query_scores in zip(queries, scores)
    }


def extract_text_from_pdf(reader):
    """Concatenate the text of every page in *reader*, each under a page header.

    Each page that yields non-empty text contributes a header line
    ``---- Page N ----`` (N is the 0-based page index), the page text, and a
    blank line. Pages with no extractable text are skipped. Surrounding
    whitespace of the combined result is stripped.

    Args:
        reader: A ``pypdf.PdfReader``, or any object with a ``pages`` iterable
            whose items provide ``extract_text()``.

    Returns:
        The combined text, or ``""`` if no page produced any text.
    """
    sections = []
    for idx, page in enumerate(reader.pages):
        # Text extraction is the expensive step, so call it once per page.
        text = page.extract_text()
        # Truthiness also tolerates a None result instead of crashing on len().
        if text:
            sections.append(f"---- Page {idx} ----\n{text}\n\n")
    return "".join(sections).strip()

def convert(filename) -> str:
    """Return the plain-text content of *filename*.

    Plain-text formats (.txt, .csv, .tsv, .md, .yaml, .toml, .json, .json5,
    .jsonc) are read verbatim as UTF-8. PDFs are passed to
    ``extract_text_from_pdf``. Extension matching is case-insensitive, so
    ``NOTES.MD`` and ``paper.PDF`` are accepted too.

    Args:
        filename: Path of the file to read, as a string.

    Returns:
        The file's text content.

    Raises:
        ValueError: If the extension is not one of the supported types.
    """
    plain_text_filetypes = (
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    )
    lowered = filename.lower()

    # Already plain text: nothing to convert, just return the content.
    if lowered.endswith(plain_text_filetypes):
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would otherwise mis-decode or reject non-ASCII UTF-8 sources.
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")


def chunk_to_length(text, max_length=512):
    """Split *text* into consecutive slices of at most *max_length* characters.

    Every slice except possibly the last is exactly ``max_length`` long, and
    joining the slices reproduces *text*. An empty string yields ``[""]``, so
    callers always get at least one chunk.

    Args:
        text: The string to split.
        max_length: Maximum slice length in characters; must be >= 1.

    Returns:
        The list of slices, in order.

    Raises:
        ValueError: If ``max_length`` is less than 1, since a zero or negative
            length can never make progress through the text.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    # Slicing by index is linear. Repeatedly re-slicing the remainder would
    # copy the whole tail on every step, which is quadratic.
    chunks = [text[start:start + max_length] for start in range(0, len(text), max_length)]
    return chunks or [text]

@spaces.GPU
def predict(query, max_characters) -> dict[str, list[str]]:
    """Return the chunks most similar to *query*, grouped by source file.

    Every chunk of every loaded document is scored against the query by dot
    product. Chunks are then taken greedily from highest to lowest score until
    the next one would push the running total past ``max_characters``.

    Args:
        query: Free-text question; encoded with the model's ``"query"`` prompt.
        max_characters: Character budget for all returned chunks combined.
            ``None`` (e.g. a cleared Gradio number field) falls back to
            ``default_max_characters``.

    Returns:
        ``{filename: [chunk, ...]}``. Files appear in the order in which their
        first chunk was selected; chunks within a file are in descending
        similarity order. The result is empty if the top chunk alone exceeds
        the budget.
    """
    if max_characters is None:
        max_characters = default_max_characters

    query_embedding = model.encode(query, prompt_name="query")

    # Score every chunk of every document as (filename, chunk, similarity).
    scored_chunks = []
    for filename, doc in docs.items():
        similarities = doc["embeddings"] @ query_embedding.T
        scored_chunks.extend(
            (filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)
        )

    # Highest similarity first; sort() is stable, so ties keep document order.
    scored_chunks.sort(key=lambda item: item[2], reverse=True)

    relevant_chunks = {}
    total_chars = 0
    for filename, chunk, _ in scored_chunks:
        # Stop at the first chunk that would overflow the budget, rather than
        # skipping it to look for shorter, lower-ranked chunks.
        if total_chars + len(chunk) > max_characters:
            break
        relevant_chunks.setdefault(filename, []).append(chunk)
        total_chars += len(chunk)

    return relevant_chunks



# Maps each source file path to its text chunks and their embeddings;
# predict() reads this at query time.
docs = {}

# sorted() makes chunk order (and therefore tie-breaking in predict) independent
# of the filesystem's directory-listing order.
for filename in sorted(glob.glob("sources/*")):
    # Skip the placeholder file and anything that is not a regular file
    # (e.g. a sub-directory), which convert() cannot read.
    if filename.endswith("add_your_files_here") or not os.path.isfile(filename):
        continue

    converted_doc = convert(filename)

    chunks = chunk_to_length(converted_doc, chunk_size)
    embeddings = model.encode(chunks)

    docs[filename] = {
        "chunks": chunks,
        "embeddings": embeddings,
    }


# Single-function UI: a query box and a character budget go in; predict()'s
# per-file relevant chunks come out as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        gr.Number(label="Max output characters", value=default_max_characters),
    ],
    outputs=[gr.JSON(label="Relevant chunks")],
    title="Hugging Chat RAG Tool",
    theme="Nymbo/Nymbo_Theme",
)
demo.launch()