import time
import os
from typing import Literal, Tuple
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import meilisearch


# Checkpoint shared by the tokenizer and the encoder so the two stay in sync.
_EMBEDDING_CHECKPOINT = "BAAI/bge-base-en-v1.5"

tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_CHECKPOINT)
# eval() returns the module itself, so loading and switching to inference
# mode can be chained into a single expression.
model = AutoModel.from_pretrained(_EMBEDDING_CHECKPOINT).eval()

# Informational only: this section reports CUDA availability but does not
# move the model to another device.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# Meilisearch handle; MEILISEARCH_KEY must be set in the environment when the
# module is imported, otherwise the lookup below raises KeyError.
meilisearch_index_name = "docs-embed"
meilisearch_client = meilisearch.Client(
    "https://edge.meilisearch.com", os.environ["MEILISEARCH_KEY"]
)
meilisearch_index = meilisearch_client.index(meilisearch_index_name)

# Output formats the search function knows how to render.
output_options = ["RAG-friendly", "human-friendly"]


def search_embeddings(
    query_text: str, output_option: Literal["RAG-friendly", "human-friendly"]
) -> Tuple[str, str]:
    """Embed a query, run a pure vector search in Meilisearch and format the hits.

    Args:
        query_text: Free-form query entered by the user. It is prefixed with a
            fixed retrieval instruction before being tokenized.
        output_option: ``"human-friendly"`` renders timing stats followed by
            each hit with its source link; ``"RAG-friendly"`` returns only the
            hit texts joined by a dashed separator line.

    Returns:
        A ``(results, sources)`` pair of Markdown strings: the formatted hits
        and a comma-separated list of links to the source pages.

    Raises:
        ValueError: If ``output_option`` is not one of the two supported
            values (previously the function silently returned ``None``).
    """
    # Fail fast, before paying for the embedding and the network round trip.
    if output_option not in ("RAG-friendly", "human-friendly"):
        raise ValueError(
            f"Unsupported output_option: {output_option!r}; "
            "expected 'RAG-friendly' or 'human-friendly'"
        )

    # step1: tokenize the query and compute its embedding.
    # perf_counter is monotonic, so the measured durations cannot be skewed by
    # system clock adjustments.
    start_time_embedding = time.perf_counter()
    query_prefix = "Represent this sentence for searching code documentation: "
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    with torch.no_grad():
        model_output = model(**query_tokens)
        # Keep the first token's vector of the first model output for each
        # sequence in the batch, then L2-normalize it.
        sentence_embeddings = model_output[0][:, 0]
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )
        # The batch holds a single query; convert its vector to a plain list
        # so it can be serialized into the search request.
        sentence_embeddings_list = sentence_embeddings[0].tolist()
        elapsed_time_embedding = time.perf_counter() - start_time_embedding

    # step2: search meilisearch. The text query is empty and semanticRatio is
    # 1.0, so ranking relies on the supplied vector only.
    start_time_meilisearch = time.perf_counter()
    response = meilisearch_index.search(
        "",
        opt_params={
            "vector": sentence_embeddings_list,
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": [
                "text",
                "source_page_url",
                "source_page_title",
                "library",
            ],
        },
    )
    elapsed_time_meilisearch = time.perf_counter() - start_time_meilisearch
    hits = response["hits"]

    sources_md = ", ".join(
        f"[\"{hit['source_page_title']}\"]({hit['source_page_url']})" for hit in hits
    )

    # step3: present the results in markdown
    if output_option == "human-friendly":
        header = f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\nmeilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
        sections = []
        for hit in hits:
            source = f'src: ["{hit["source_page_title"]}"]({hit["source_page_url"]})'
            sections.append(hit["text"] + f"\n\n{source}\n\n---\n\n")
        return header + "".join(sections), sources_md

    # output_option == "RAG-friendly" (validated above)
    hit_text_str = "\n------------\n".join(hit["text"] for hit in hits)
    return hit_text_str, sources_md


# Input widgets, built up front so the Interface call below stays compact.
_query_box = gr.Textbox(
    label="enter your query", placeholder="Type Markdown here...", lines=10
)
_output_mode_radio = gr.Radio(
    label="Select an output option",
    choices=output_options,
    value="RAG-friendly",
)

# search_embeddings returns two strings, so two Markdown panes receive them:
# the formatted hits first, then the list of source links.
demo = gr.Interface(
    fn=search_embeddings,
    inputs=[_query_box, _output_mode_radio],
    outputs=[gr.Markdown(), gr.Markdown()],
    title="HF Docs Embeddings Explorer",
    allow_flagging="never",
)

# Only start the web server when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()