"""Gradio app: retrieve Thirukkural couplets relevant to a user question.

Embeds the question with a Qwen3 embedding model and ranks pre-computed
Kural embeddings by cosine similarity, returning the best matches as
Markdown.
"""

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoTokenizer

# --- Configuration ---
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"  # Placeholder for your model
KURAL_EMBEDDINGS_FILE = "kural_embeds.pt"  # pre-computed couplet embeddings
KURAL_DATA_FILE = "thirukural.tsv"  # TSV with Kural text (kural, kural_eng, chapter)
TOP_K = 3  # number of couplets returned per query (matches the UI description)

# --- Load Resources ---
# NOTE(review): load failures are logged and execution continues; a missing
# resource surfaces later as a NameError. Kept best-effort as in the original.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model.eval()  # inference-only app; disable dropout etc.
except Exception as e:
    print(f"Error loading Transformer model: {e}")
    # Handle model loading failure (e.g., exit or use a fallback)

try:
    # weights_only=True: the file holds plain tensors, so avoid unpickling
    # arbitrary objects (safer default for torch.load).
    kural_embeddings = torch.load(KURAL_EMBEDDINGS_FILE, weights_only=True)
except FileNotFoundError:
    print(f"Error: The file {KURAL_EMBEDDINGS_FILE} was not found.")

try:
    df = pd.read_csv(KURAL_DATA_FILE, sep='\t')
except FileNotFoundError:
    print(f"Error: The file {KURAL_DATA_FILE} was not found.")


def get_detailed_instruct(query: str) -> str:
    """Wrap *query* in the instruction template expected by the embedding model."""
    return f'Instruct: Given a question, retrieve relevant Thirukkural couplets that are most relevant to, or answer the question\nQuery:{query}'


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Return the hidden state of each sequence's last non-padded token.

    Handles both left- and right-padded batches (Qwen3 embedding convention
    is to pool the final real token).
    """
    # If every sequence attends at the final position, the batch is
    # left-padded and the last column already holds each sequence's last token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # Right padding: index each row at its own last real token.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[
        torch.arange(batch_size, device=last_hidden_states.device),
        sequence_lengths,
    ]


def find_relevant_kurals(question: str, top_k: int = TOP_K) -> list:
    """Return the *top_k* Kurals most similar to *question*.

    Each result is a dict with keys ``kural_ta``, ``kural_eng``, ``chapter``
    and ``similarity`` (cosine similarity as a float).
    """
    batch_dict = tokenizer(
        [get_detailed_instruct(question)],
        max_length=128,
        padding=False,
        truncation=True,
        return_tensors='pt',
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**batch_dict)
    query_embedding = last_token_pool(
        outputs.last_hidden_state, batch_dict['attention_mask']
    ).detach().cpu()

    # Cosine similarity: L2-normalize, then dot the query row (row 0)
    # against every Kural row.
    all_embeddings = torch.cat((query_embedding, kural_embeddings), dim=0)
    all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
    scores = all_embeddings[:1] @ all_embeddings[1:].T

    # Indices of the top_k most similar couplets.
    top_indices = torch.topk(scores[0, :], top_k).indices.tolist()

    # Prepare results
    results = []
    for i in top_indices:
        row = df.iloc[i]
        results.append({
            "kural_ta": row.get("kural", "N/A"),
            "kural_eng": row.get("kural_eng", "N/A"),
            "chapter": row.get("chapter", "N/A"),
            # .item(): hand callers a plain float, not a 0-dim tensor.
            "similarity": scores[0, i].item(),
        })
    return results


def rag_interface(question: str) -> str:
    """Gradio callback: format the retrieved Kurals as a Markdown string."""
    if not question:
        return "Please enter a question."
    kurals = find_relevant_kurals(question)
    output = ""
    for kural in kurals:
        output += f"**Kural (Tamil):** {kural['kural_ta']}\n"
        output += f"**Kural (English):** {kural['kural_eng']}\n"
        output += f"**Chapter:** {kural['chapter']}\n"
        output += f"**Similarity:** {kural['similarity']:.2f}\n\n---\n"
    return output


# --- Gradio Interface ---
iface = gr.Interface(
    fn=rag_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here:"),
    outputs="markdown",
    title="Kural for your question",
    description="Ask a vexing question and get 3 relevant Thirukural couplets using embedding-similarity based search.",
    flagging_mode='never',
)

if __name__ == "__main__":
    iface.launch()