import html

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Load the embedding model and its matching tokenizer once, at import time;
# every request below reuses these module-level objects. eval() puts the model
# in evaluation mode (dropout disabled) so repeated calls give the same vector.
model_name = "Supabase/gte-small"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)
model.eval()

def get_embedding(text):
    """Embed *text* with the module-level model and return a unit-length vector.

    The last hidden layer's token vectors are averaged using the attention mask
    as weights, so padding positions contribute nothing. The average is
    L2-normalised and every size-1 dimension is squeezed away, so a single
    input string yields a 1-D tensor of length ``hidden_dim``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state  # (batch_size, seq_len, hidden_dim)

    mask = encoded["attention_mask"].unsqueeze(-1)  # (batch_size, seq_len, 1)
    masked_total = (hidden * mask).sum(dim=1)
    token_count = mask.sum(dim=1)
    mean_pooled = masked_total / token_count

    return F.normalize(mean_pooled, p=2, dim=1).squeeze()

def _word_relevance(paragraph, query_embedding):
    """Score each word of *paragraph* by how query-like its contextual embedding is.

    Returns ``(spans, relevance)``:
    - ``spans`` holds the ``(start, end)`` character offsets of each word in the
      original paragraph text.
    - ``relevance`` holds each word's cosine similarity to ``query_embedding``
      minus the mean over all scored words. Positive values mark words that are
      more similar to the query than the paragraph's average word.

    Words past the tokenizer's truncation limit are not scored. Both lists are
    empty when no word could be scored.
    """
    # Offsets and word ids need a "fast" tokenizer. AutoTokenizer returns one by
    # default when available. NOTE(review): confirm one exists for this model.
    encoding = tokenizer(
        paragraph, return_tensors="pt", truncation=True, return_offsets_mapping=True
    )
    # The model's forward() does not accept offset_mapping, so remove it first.
    offsets = encoding.pop("offset_mapping")[0].tolist()
    word_ids = encoding.word_ids(0)
    with torch.no_grad():
        token_states = model(**encoding).last_hidden_state[0]  # (seq_len, hidden_dim)

    # Group sub-word token positions by the word they belong to. Special tokens
    # added by the tokenizer have no word id and are skipped.
    token_groups = {}
    for position, word_id in enumerate(word_ids):
        if word_id is not None:
            token_groups.setdefault(word_id, []).append(position)
    if not token_groups:
        return [], []

    groups = list(token_groups.values())
    spans = [(offsets[group[0]][0], offsets[group[-1]][1]) for group in groups]
    word_vectors = torch.stack([token_states[group].mean(dim=0) for group in groups])
    scores = F.cosine_similarity(word_vectors, query_embedding.unsqueeze(0), dim=1)
    # Centre on the paragraph mean. Raw contextual similarities sit in a narrow,
    # high band, which would make a fixed absolute threshold meaningless.
    return spans, (scores - scores.mean()).tolist()


def _render_highlighted(paragraph, spans, relevance, threshold):
    """Return *paragraph* as HTML-escaped text with words scoring above *threshold* in <b>."""
    pieces = []
    cursor = 0
    for (start, end), score in zip(spans, relevance):
        pieces.append(html.escape(paragraph[cursor:start]))
        word = html.escape(paragraph[start:end])
        pieces.append(f"<b>{word}</b>" if score > threshold else word)
        cursor = end
    # Any text after the last scored word, such as a truncated tail, is kept as-is.
    pieces.append(html.escape(paragraph[cursor:]))
    return "".join(pieces)


def get_similarity_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
    """Rank up to three paragraphs by similarity to *query* and highlight relevant words.

    Args:
        query: Search text. Empty, whitespace-only or ``None`` is rejected.
        paragraph1, paragraph2, paragraph3: Candidate paragraphs. Empty,
            whitespace-only or ``None`` entries are ignored.
        threshold_weight: A word is bolded when its similarity to the query
            exceeds the paragraph's average word similarity by more than this
            amount. Values below 0.02 are raised to 0.02.

    Returns:
        An HTML table with one row per paragraph, sorted by descending cosine
        similarity (rounded to 4 places). Each row shows the paragraph's original
        text, HTML-escaped, with highlighted words in ``<b>``. If the query or
        all paragraphs are missing, returns a plain prompt message instead.
    """
    paragraphs = [p for p in (paragraph1, paragraph2, paragraph3) if p and p.strip()]
    if not (query and query.strip()) or not paragraphs:
        return "Please provide both a query and at least one document paragraph."

    # Same floor as the UI slider's minimum, enforced for non-UI callers too.
    threshold = max(0.02, threshold_weight)
    query_embedding = get_embedding(query)

    ranked_paragraphs = []
    for paragraph in paragraphs:
        similarity = F.cosine_similarity(
            query_embedding, get_embedding(paragraph), dim=0
        ).item()
        spans, relevance = _word_relevance(paragraph, query_embedding)
        ranked_paragraphs.append(
            {
                "similarity": similarity,
                "highlighted_text": _render_highlighted(
                    paragraph, spans, relevance, threshold
                ),
            }
        )
    ranked_paragraphs.sort(key=lambda item: item["similarity"], reverse=True)

    rows = "".join(
        f"<tr><td>{round(item['similarity'], 4)}</td><td>{item['highlighted_text']}</td></tr>"
        for item in ranked_paragraphs
    )
    return (
        "<table border='1' style='width:100%; border-collapse: collapse;'>"
        "<tr><th>Cosine Similarity</th><th>Highlighted Paragraph</th></tr>"
        f"{rows}</table>"
    )

# Wire the ranking function into a Gradio form. The inputs are a query box,
# three paragraph boxes (the last two labelled optional) and a threshold
# slider. They are passed positionally, in this order, to
# get_similarity_and_excerpt.
interface = gr.Interface(
    fn=get_similarity_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        *[
            gr.Textbox(label=box_label, placeholder=box_hint, lines=4)
            for box_label, box_hint in (
                ("Document Paragraph 1", "Enter a paragraph to match..."),
                ("Document Paragraph 2 (optional)", "Enter another paragraph..."),
                ("Document Paragraph 3 (optional)", "Enter another paragraph..."),
            )
        ],
        gr.Slider(
            minimum=0.02,
            maximum=0.5,
            value=0.1,
            step=0.01,
            label="Similarity Threshold",
        ),
    ],
    outputs=[gr.HTML(label="Ranked Paragraphs")],
    title="Embedding-Based Similarity Highlighting",
    description=(
        "Uses cosine similarity with Supabase/gte-small embeddings "
        "to rank paragraphs and highlight relevant words."
    ),
    allow_flagging="never",
    # Re-run automatically whenever an input changes, with no Submit click.
    live=True,
)

if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    interface.launch()