upload
Browse files- app.py +103 -0
- kural_embeds.pt +3 -0
- requirements.txt +5 -0
- thirukural.tsv +0 -0
app.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import AutoModel, AutoTokenizer
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
# --- Configuration ---
|
9 |
+
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B" # Placeholder for your model
|
10 |
+
KURAL_EMBEDDINGS_FILE = "kural_embeds.pt"
|
11 |
+
KURAL_DATA_FILE = "thirukural.tsv" # You'll need a CSV with the Kural text
|
12 |
+
|
13 |
+
# --- Load Resources ---
|
14 |
+
try:
|
15 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
16 |
+
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
17 |
+
except Exception as e:
|
18 |
+
print(f"Error loading Transformer model: {e}")
|
19 |
+
# Handle model loading failure (e.g., exit or use a fallback)
|
20 |
+
|
21 |
+
try:
|
22 |
+
kural_embeddings = torch.load(KURAL_EMBEDDINGS_FILE)
|
23 |
+
except FileNotFoundError:
|
24 |
+
print(f"Error: The file {KURAL_EMBEDDINGS_FILE} was not found.")
|
25 |
+
|
26 |
+
try:
|
27 |
+
df = pd.read_csv(KURAL_DATA_FILE, sep='\t')
|
28 |
+
except FileNotFoundError:
|
29 |
+
print(f"Error: The file {KURAL_DATA_FILE} was not found.")
|
30 |
+
|
31 |
+
def get_detailed_instruct(query: str) -> str:
|
32 |
+
return f'Instruct: Given a question, retrieve relevant Thirukkural couplets that are most relevant to, or answer the question\nQuery:{query}'
|
33 |
+
|
34 |
+
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
35 |
+
'''
|
36 |
+
Returns pooled embedding of last token from Qwen3
|
37 |
+
'''
|
38 |
+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
39 |
+
if left_padding:
|
40 |
+
return last_hidden_states[:, -1]
|
41 |
+
else:
|
42 |
+
sequence_lengths = attention_mask.sum(dim=1) - 1
|
43 |
+
batch_size = last_hidden_states.shape[0]
|
44 |
+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
45 |
+
|
46 |
+
def find_relevant_kurals(question):
|
47 |
+
"""
|
48 |
+
Finds the top 5 most relevant Kurals using cosine similarity.
|
49 |
+
"""
|
50 |
+
batch_dict = tokenizer([get_detailed_instruct(question)], max_length=128, padding=False, truncation=True, return_tensors='pt')
|
51 |
+
outputs = model(**batch_dict)
|
52 |
+
query_embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']).detach().cpu()
|
53 |
+
|
54 |
+
# Calculate similarities
|
55 |
+
all_embeddings = torch.cat((query_embedding,kural_embeddings), axis=0)
|
56 |
+
all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
|
57 |
+
scores = all_embeddings[:1]@all_embeddings[1:].T
|
58 |
+
|
59 |
+
|
60 |
+
# Get top 5 indices
|
61 |
+
top_indices = torch.topk(scores[0,:], 3).indices.tolist()
|
62 |
+
|
63 |
+
# Prepare results
|
64 |
+
results = []
|
65 |
+
for i in top_indices:
|
66 |
+
results.append({
|
67 |
+
"kural_ta": df.iloc[i].get("kural", "N/A"),
|
68 |
+
"kural_eng": df.iloc[i].get("kural_eng", "N/A"),
|
69 |
+
"chapter": df.iloc[i].get("chapter", "N/A"),
|
70 |
+
"similarity": scores[0,i]
|
71 |
+
})
|
72 |
+
return results
|
73 |
+
|
74 |
+
def rag_interface(question):
|
75 |
+
"""
|
76 |
+
The main function for the Gradio interface.
|
77 |
+
"""
|
78 |
+
if not question:
|
79 |
+
return "Please enter a question."
|
80 |
+
|
81 |
+
kurals = find_relevant_kurals(question)
|
82 |
+
|
83 |
+
output = ""
|
84 |
+
for kural in kurals:
|
85 |
+
output += f"**Kural (Tamil):** {kural['kural_ta']}<br>"
|
86 |
+
output += f"**Kural (English):** {kural['kural_eng']}<br>"
|
87 |
+
output += f"**Chapter:** {kural['chapter']}<br>"
|
88 |
+
output += f"**Similarity:** {kural['similarity']:.2f}\n\n---\n"
|
89 |
+
|
90 |
+
return output
|
91 |
+
|
92 |
+
# --- Gradio Interface ---
|
93 |
+
iface = gr.Interface(
|
94 |
+
fn=rag_interface,
|
95 |
+
inputs=gr.Textbox(lines=2, placeholder="Enter your question here:"),
|
96 |
+
outputs="markdown",
|
97 |
+
title="Kural for your question.",
|
98 |
+
description="Ask a vexing question and get 3 relevant Thirukural couplets using embedding-similarity based search.",
|
99 |
+
flagging_mode='never'
|
100 |
+
)
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
iface.launch()
|
kural_embeds.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe85848bea6c0971803d987dccec28bff72e534fc4cf4f95558b20218b17e7c2
|
3 |
+
size 5448885
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
transformers
|
3 |
+
pandas
|
4 |
+
numpy
|
5 |
+
torch
|
thirukural.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|