ohalkhateeb committed on
Commit 309766f · verified · 1 Parent(s): aa8c5f5

Update app.py

Files changed (1)
  1. app.py +53 -35
app.py CHANGED
@@ -1,43 +1,61 @@
- import os
- from bs4 import BeautifulSoup
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from sentence_transformers import SentenceTransformer
  import faiss
  import numpy as np
  import pickle

- # Function to load documents
- def load_documents(directory):
-     documents = []
-     for filename in os.listdir(directory):
-         if filename.endswith(".html"):
-             file_path = os.path.join(directory, filename)
-             with open(file_path, "r", encoding="latin-1") as f:
-                 soup = BeautifulSoup(f, "html.parser")
-                 text = soup.get_text(separator=" ", strip=True)
-                 documents.append(text)
-     return documents
-
- # Load and split documents
- print("Loading and splitting documents...")
- documents = load_documents("./legislation")
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
- chunks = []
- for doc in documents:
-     chunks.extend(text_splitter.split_text(doc))

- # Create embeddings and FAISS index
- print("Generating embeddings...")
  embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
- embeddings = embedding_model.encode(chunks, show_progress_bar=True)
- dimension = embeddings.shape[1]
- index = faiss.IndexFlatL2(dimension)
- index.add(np.array(embeddings))

- # Save chunks and index
- print("Saving precomputed data...")
- with open("chunks.pkl", "wb") as f:
-     pickle.dump(chunks, f)
- faiss.write_index(index, "index.faiss")

- print("Preprocessing complete!")
+ import gradio as gr
  import faiss
  import numpy as np
  import pickle
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForCausalLM

+ # Load precomputed chunks and FAISS index
+ print("Loading precomputed data...")
+ with open("chunks.pkl", "rb") as f:
+     chunks = pickle.load(f)
+ index = faiss.read_index("index.faiss")

+ # Load embedding model (for queries only)
  embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

+ # Load Jais model and tokenizer
+ model_name = "inceptionai/jais-13b"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ # RAG function
+ def get_response(query, k=3):
+     query_embedding = embedding_model.encode([query])
+     distances, indices = index.search(np.array(query_embedding), k)
+     retrieved_chunks = [chunks[i] for i in indices[0]]
+     context = " ".join(retrieved_chunks)
+     prompt = f"استنادًا إلى الوثائق التالية: {context}، أجب على السؤال: {query}"
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=200,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.9
+     )
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return response.split(query)[-1].strip()
+
+ # Gradio interface
+ with gr.Blocks(title="Dubai Legislation Chatbot") as demo:
+     gr.Markdown("# Dubai Legislation Chatbot\nاسأل أي سؤال حول تشريعات دبي")
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox(placeholder="اكتب سؤالك هنا...", rtl=True)
+     clear = gr.Button("مسح")
+
+     def user(user_message, history):
+         return "", history + [[user_message, None]]
+
+     def bot(history):
+         user_message = history[-1][0]
+         bot_message = get_response(user_message)
+         history[-1][1] = bot_message
+         return history
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, chatbot, chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)

+ demo.launch()
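
The new app.py loads inceptionai/jais-13b with default from_pretrained settings. As an illustrative, hedged sketch only (not part of this commit): Jais repositories ship custom modeling code, so loading typically needs trust_remote_code=True, and a 13B checkpoint usually has to be loaded in half precision with device mapping to fit on GPU hardware. The dtype and device_map choices below are assumptions, device_map="auto" depends on the accelerate package, and generate_answer is a hypothetical helper mirroring the generation settings used in get_response.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "inceptionai/jais-13b"

# trust_remote_code=True is typically required because the Jais repository
# provides its own modeling code (assumption about the model family, not this commit).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # assumption: half precision to reduce memory use
    device_map="auto",          # assumption: requires `accelerate`; places weights on available devices
    trust_remote_code=True,
)

def generate_answer(prompt: str) -> str:
    # Hypothetical helper: same generation settings as get_response in app.py,
    # but inputs are moved to the device the model weights were placed on.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

If the Space runs on CPU only, a smaller model or a quantized load would be the more practical route; either way the retrieval pipeline above is unaffected.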