ohalkhateeb commited on
Commit
3d77d1e
·
verified ·
1 Parent(s): 3319c10

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import pickle

import gradio as gr
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face access token; must be configured in the Space settings.
# (Fix: `os` was used here without being imported, crashing at startup.)
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set. Please configure it in Space settings.")

# Load precomputed text chunks and the FAISS index built over their embeddings.
with open("chunks.pkl", "rb") as f:
    chunks = pickle.load(f)
index = faiss.read_index("index.faiss")

# Embedding model — must match the one used when the index was built.
embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Generator model and tokenizer (AraGPT2 causal LM).
model_name = "aubmindlab/aragpt2-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, token=HF_TOKEN, trust_remote_code=True)
28
+
29
# RAG function to retrieve and generate a response
def get_response(query, k=3):
    """Answer *query* by retrieving the top-k chunks and prompting the LM.

    Parameters:
        query: the user's question (string).
        k: number of chunks to retrieve from the FAISS index (default 3).

    Returns:
        The generated answer text with the echoed prompt stripped off.
    """
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k)
    # FAISS pads the result with -1 when the index holds fewer than k
    # vectors; skip those so chunks[-1] is not silently included.
    retrieved_chunks = [chunks[i] for i in indices[0] if i >= 0]
    context = " ".join(retrieved_chunks)
    prompt = f"Based on the following documents: {context}, answer the question: {query}"
    # Truncate to the model's context budget before generation.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the query.
    return response.split(query)[-1].strip()
46
+
47
# Gradio interface: two-step chat flow — enqueue the question, then answer it.
with gr.Blocks(title="Dubai Legislation Chatbot") as demo:
    gr.Markdown("# Dubai Legislation Chatbot\nAsk any question about Dubai legislation")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your question here...")
    clear = gr.Button("Clear")

    def queue_question(user_message, history):
        # Append the question with a pending (None) answer and clear the box.
        return "", history + [[user_message, None]]

    def answer_question(history):
        # Fill in the answer slot of the most recent history entry.
        question = history[-1][0]
        history[-1][1] = get_response(question)
        return history

    # Submitting the textbox first records the question (unqueued, so the UI
    # updates immediately), then runs generation to produce the reply.
    msg.submit(queue_question, [msg, chatbot], [msg, chatbot], queue=False).then(
        answer_question, chatbot, chatbot
    )
    # Clearing resets the chat history to empty.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()