import gradio as gr
from transformers import AutoModel, AutoTokenizer
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor
# --- Configuration ---
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"  # Swap in your own embedding model here
KURAL_EMBEDDINGS_FILE = "kural_embeds.pt"  # Precomputed couplet embeddings (see the hedged sketch below)
KURAL_DATA_FILE = "thirukural.tsv"  # TSV with the Kural text, English translation, and chapter columns
# --- Load Resources ---
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model.eval()  # inference only
except Exception as e:
    print(f"Error loading Transformer model: {e}")
    raise  # without the model the app cannot run, so fail fast
try:
    kural_embeddings = torch.load(KURAL_EMBEDDINGS_FILE, map_location="cpu")
except FileNotFoundError:
    print(f"Error: The file {KURAL_EMBEDDINGS_FILE} was not found.")
    raise
try:
    df = pd.read_csv(KURAL_DATA_FILE, sep='\t')
except FileNotFoundError:
    print(f"Error: The file {KURAL_DATA_FILE} was not found.")
    raise
def get_detailed_instruct(query: str) -> str:
    # Qwen3-Embedding expects retrieval queries to carry a task instruction;
    # documents are embedded without one.
    return f'Instruct: Given a question, retrieve Thirukkural couplets that are most relevant to, or answer, the question\nQuery:{query}'
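# e.g. get_detailed_instruct("What is patience?")
# -> 'Instruct: Given a question, ...\nQuery:What is patience?'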
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    '''
    Pools each sequence to the hidden state of its last non-padding token,
    the pooling scheme recommended for Qwen3-Embedding models.
    '''
    # If every sequence's final position is a real token, the batch is
    # left-padded and the last column can be taken directly.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        # Right padding: index the last real token of each sequence.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
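# The app loads kural_embeds.pt ready-made. Below is a minimal sketch of how it
# could be produced with the same model and pooling; the function name, batch
# size, and the choice of text column are assumptions, not from the original.
def build_kural_embeddings(texts, batch_size=32, out_file=KURAL_EMBEDDINGS_FILE):
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = tokenizer(texts[start:start + batch_size], max_length=128,
                              padding=True, truncation=True, return_tensors='pt')
            out = model(**batch)
            # Documents are embedded without the instruction prefix, per Qwen3-Embedding usage
            chunks.append(last_token_pool(out.last_hidden_state, batch['attention_mask']).cpu())
    embeddings = torch.cat(chunks, dim=0)  # shape: (num_kurals, hidden_dim)
    torch.save(embeddings, out_file)
    return embeddings
# Usage sketch (assuming a "kural" column): build_kural_embeddings(df["kural"].tolist())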
def find_relevant_kurals(question):
    """
    Finds the top 3 most relevant Kurals using cosine similarity.
    """
    batch_dict = tokenizer([get_detailed_instruct(question)], max_length=128, padding=False, truncation=True, return_tensors='pt')
    with torch.no_grad():  # inference only; skip gradient bookkeeping
        outputs = model(**batch_dict)
    query_embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']).detach().cpu()
    # Normalize query and corpus together so the dot product below is cosine similarity
    all_embeddings = torch.cat((query_embedding, kural_embeddings), dim=0)
    all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
    scores = all_embeddings[:1] @ all_embeddings[1:].T  # shape: (1, num_kurals)
    # Get top 3 indices
    top_indices = torch.topk(scores[0, :], 3).indices.tolist()
    # Prepare results
    results = []
    for i in top_indices:
        results.append({
            "kural_ta": df.iloc[i].get("kural", "N/A"),
            "kural_eng": df.iloc[i].get("kural_eng", "N/A"),
            "chapter": df.iloc[i].get("chapter", "N/A"),
            "similarity": scores[0, i].item()
        })
    return results
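# Shape of the returned list (values purely illustrative):
# find_relevant_kurals("How should one treat guests?") ->
# [{'kural_ta': '...', 'kural_eng': '...', 'chapter': '...', 'similarity': 0.73}, ...]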
def rag_interface(question):
    """
    The main function for the Gradio interface.
    """
    if not question:
        return "Please enter a question."
    kurals = find_relevant_kurals(question)
    output = ""
    for kural in kurals:
        # Blank lines between fields so each renders as its own Markdown paragraph
        output += f"**Kural (Tamil):** {kural['kural_ta']}\n\n"
        output += f"**Kural (English):** {kural['kural_eng']}\n\n"
        output += f"**Chapter:** {kural['chapter']}\n\n"
        output += f"**Similarity:** {kural['similarity']:.2f}\n\n---\n"
    return output
# --- Gradio Interface ---
iface = gr.Interface(
    fn=rag_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="markdown",
    title="Kural for your question",
    description="Ask a vexing question and get 3 relevant Thirukkural couplets via embedding-similarity search.",
    flagging_mode='never'
)
if __name__ == "__main__":
    iface.launch()