ohalkhateeb committed on
Commit
c30f47d
·
verified ·
1 Parent(s): a55bcb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py CHANGED
@@ -1,3 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
  # Print current Gradio version
 
1
import os
import pickle

import faiss
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face access token; must be supplied as a secret in the Space settings.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set. Please configure it in Space settings.")

# Load the precomputed document chunks and the FAISS index built from them.
# Both files are expected to live next to this script in the Space repo.
with open("chunks.pkl", "rb") as f:
    chunks = pickle.load(f)
index = faiss.read_index("index.faiss")

# Embedding model used for query encoding — must be the same model that was
# used to embed the chunks when the index was built, or retrieval is garbage.
embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Load the generator model and tokenizer.
# NOTE: despite what the original comment said, this is AraGPT2
# (aubmindlab/aragpt2-base), not a Jais model.
model_name = "aubmindlab/aragpt2-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, token=HF_TOKEN, trust_remote_code=True)
30
# RAG pipeline: retrieve the most relevant chunks, then generate an answer.
def get_response(query, k=3):
    """Answer *query* using retrieval-augmented generation.

    Embeds the query, retrieves the top-*k* most similar document chunks
    from the FAISS index, builds a prompt from them, and generates an
    answer with the causal language model.

    Args:
        query: The user's question as a string.
        k: Number of chunks to retrieve (default 3).

    Returns:
        The generated answer text with the echoed prompt stripped off.
    """
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k)
    # FAISS pads the result with -1 when the index holds fewer than k
    # vectors; chunks[-1] would then silently return the *last* chunk,
    # so negative indices must be filtered out.
    retrieved_chunks = [chunks[i] for i in indices[0] if i >= 0]
    context = " ".join(retrieved_chunks)
    prompt = f"Based on the following documents: {context}, answer the question: {query}"
    # Truncate the prompt so context + generated tokens fit the model window.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Causal LMs echo the prompt; keep only the text after the question.
    return response.split(query)[-1].strip()
47
+
48
+ # Gradio interface
49
+
50
  import gradio as gr
51
 
52
  # Print current Gradio version