venkatasg commited on
Commit
641f5e1
·
verified ·
1 Parent(s): b13650d
Files changed (4) hide show
  1. app.py +103 -0
  2. kural_embeds.pt +3 -0
  3. requirements.txt +5 -0
  4. thirukural.tsv +0 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModel, AutoTokenizer
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+ # --- Configuration ---
9
+ MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B" # Placeholder for your model
10
+ KURAL_EMBEDDINGS_FILE = "kural_embeds.pt"
11
+ KURAL_DATA_FILE = "thirukural.tsv" # You'll need a CSV with the Kural text
12
+
13
+ # --- Load Resources ---
14
+ try:
15
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
+ model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
17
+ except Exception as e:
18
+ print(f"Error loading Transformer model: {e}")
19
+ # Handle model loading failure (e.g., exit or use a fallback)
20
+
21
+ try:
22
+ kural_embeddings = torch.load(KURAL_EMBEDDINGS_FILE)
23
+ except FileNotFoundError:
24
+ print(f"Error: The file {KURAL_EMBEDDINGS_FILE} was not found.")
25
+
26
+ try:
27
+ df = pd.read_csv(KURAL_DATA_FILE, sep='\t')
28
+ except FileNotFoundError:
29
+ print(f"Error: The file {KURAL_DATA_FILE} was not found.")
30
+
31
+ def get_detailed_instruct(query: str) -> str:
32
+ return f'Instruct: Given a question, retrieve relevant Thirukkural couplets that are most relevant to, or answer the question\nQuery:{query}'
33
+
34
+ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
35
+ '''
36
+ Returns pooled embedding of last token from Qwen3
37
+ '''
38
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
39
+ if left_padding:
40
+ return last_hidden_states[:, -1]
41
+ else:
42
+ sequence_lengths = attention_mask.sum(dim=1) - 1
43
+ batch_size = last_hidden_states.shape[0]
44
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
45
+
46
+ def find_relevant_kurals(question):
47
+ """
48
+ Finds the top 5 most relevant Kurals using cosine similarity.
49
+ """
50
+ batch_dict = tokenizer([get_detailed_instruct(question)], max_length=128, padding=False, truncation=True, return_tensors='pt')
51
+ outputs = model(**batch_dict)
52
+ query_embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']).detach().cpu()
53
+
54
+ # Calculate similarities
55
+ all_embeddings = torch.cat((query_embedding,kural_embeddings), axis=0)
56
+ all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
57
+ scores = all_embeddings[:1]@all_embeddings[1:].T
58
+
59
+
60
+ # Get top 5 indices
61
+ top_indices = torch.topk(scores[0,:], 3).indices.tolist()
62
+
63
+ # Prepare results
64
+ results = []
65
+ for i in top_indices:
66
+ results.append({
67
+ "kural_ta": df.iloc[i].get("kural", "N/A"),
68
+ "kural_eng": df.iloc[i].get("kural_eng", "N/A"),
69
+ "chapter": df.iloc[i].get("chapter", "N/A"),
70
+ "similarity": scores[0,i]
71
+ })
72
+ return results
73
+
74
+ def rag_interface(question):
75
+ """
76
+ The main function for the Gradio interface.
77
+ """
78
+ if not question:
79
+ return "Please enter a question."
80
+
81
+ kurals = find_relevant_kurals(question)
82
+
83
+ output = ""
84
+ for kural in kurals:
85
+ output += f"**Kural (Tamil):** {kural['kural_ta']}<br>"
86
+ output += f"**Kural (English):** {kural['kural_eng']}<br>"
87
+ output += f"**Chapter:** {kural['chapter']}<br>"
88
+ output += f"**Similarity:** {kural['similarity']:.2f}\n\n---\n"
89
+
90
+ return output
91
+
92
+ # --- Gradio Interface ---
93
+ iface = gr.Interface(
94
+ fn=rag_interface,
95
+ inputs=gr.Textbox(lines=2, placeholder="Enter your question here:"),
96
+ outputs="markdown",
97
+ title="Kural for your question.",
98
+ description="Ask a vexing question and get 3 relevant Thirukural couplets using embedding-similarity based search.",
99
+ flagging_mode='never'
100
+ )
101
+
102
+ if __name__ == "__main__":
103
+ iface.launch()
kural_embeds.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe85848bea6c0971803d987dccec28bff72e534fc4cf4f95558b20218b17e7c2
3
+ size 5448885
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ pandas
4
+ numpy
5
+ torch
thirukural.tsv ADDED
The diff for this file is too large to render. See raw diff