Files changed (1)
  1. app.py +307 -94
app.py CHANGED
@@ -1,10 +1,15 @@
 import gradio as gr
 import os
+from pathlib import Path
+import json
+import csv
+import pandas as pd
+from tqdm import tqdm
 api_token = os.getenv("HF_TOKEN")
 
 
 from langchain_community.vectorstores import FAISS
-from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import PyPDFLoader, TextLoader, CSVLoader, JSONLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
@@ -20,56 +25,130 @@ import torch
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-# Load and split PDF document
-def load_doc(list_file_path):
-    # Processing for one document only
-    # loader = PyPDFLoader(file_path)
-    # pages = loader.load()
-    loaders = [PyPDFLoader(x) for x in list_file_path]
-    pages = []
-    for loader in loaders:
-        pages.extend(loader.load())
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size = 1024,
-        chunk_overlap = 64
-    )
-    doc_splits = text_splitter.split_documents(pages)
+# Load and split documents of various types
+def load_doc(list_file_path, progress=gr.Progress()):
+    doc_splits = []
+
+    progress(0, desc="Preparing to load documents")
+    total_files = len(list_file_path)
+
+    for i, file_path in enumerate(list_file_path):
+        progress((i/total_files) * 0.5, desc=f"Loading {Path(file_path).name}")
+        file_ext = Path(file_path).suffix.lower()
+
+        try:
+            # PDF documents
+            if file_ext == '.pdf':
+                loader = PyPDFLoader(file_path)
+                pages = loader.load()
+                doc_splits.extend(split_documents(pages))
+
+            # Text-based documents
+            elif file_ext in ['.txt', '.md', '.py', '.js', '.html', '.css']:
+                loader = TextLoader(file_path)
+                documents = loader.load()
+                doc_splits.extend(split_documents(documents))
+
+            # CSV files
+            elif file_ext == '.csv':
+                loader = CSVLoader(file_path)
+                documents = loader.load()
+                doc_splits.extend(split_documents(documents))
+
+            # JSON files
+            elif file_ext in ['.json', '.jsonl']:
+                # For JSON, we need to determine if it's JSON or JSONL
+                with open(file_path, 'r') as f:
+                    content = f.read().strip()
+                if content.startswith('[') or content.startswith('{'):
+                    # Regular JSON
+                    loader = JSONLoader(
+                        file_path=file_path,
+                        jq_schema='.',
+                        text_content=False
+                    )
+                    documents = loader.load()
+                    doc_splits.extend(split_documents(documents))
+                else:
+                    # JSONL - process line by line
+                    documents = []
+                    with open(file_path, 'r') as f:
+                        for line in f:
+                            if line.strip():
+                                try:
+                                    json_obj = json.loads(line)
+                                    text = json.dumps(json_obj)
+                                    documents.append(text)
+                                except json.JSONDecodeError:
+                                    continue
+
+                    text_splitter = RecursiveCharacterTextSplitter(
+                        chunk_size=1024,
+                        chunk_overlap=64
+                    )
+                    doc_splits.extend(text_splitter.create_documents(documents))
+        except Exception as e:
+            print(f"Error processing {file_path}: {str(e)}")
+            continue
+
+        progress(0.5 + (i/total_files) * 0.5, desc=f"Processed {Path(file_path).name}")
+
     return doc_splits
 
+# Helper function to split documents
+def split_documents(documents):
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1024,
+        chunk_overlap=64
+    )
+    return text_splitter.split_documents(documents)
+
 # Create vector database
-def create_db(splits):
+def create_db(splits, progress=gr.Progress()):
+    progress(0, desc="Creating vector database")
     embeddings = HuggingFaceEmbeddings()
-    vectordb = FAISS.from_documents(splits, embeddings)
+
+    # Create vectors with progress bar
+    total_chunks = len(splits)
+    vectordb = FAISS.from_documents(
+        documents=splits,
+        embedding=embeddings
+    )
+
+    progress(1.0, desc="Vector database creation complete")
     return vectordb
 
 
 # Initialize langchain LLM chain
-# @spaces.GPU(duration=60)
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    progress(0, desc=f"Initializing {llm_model}")
+
     if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
         llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            huggingfacehub_api_token = api_token,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
+            huggingfacehub_api_token=api_token,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
         )
     else:
         llm = HuggingFaceEndpoint(
-            huggingfacehub_api_token = api_token,
+            huggingfacehub_api_token=api_token,
             repo_id=llm_model,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
         )
 
+    progress(0.5, desc="Setting up memory and retriever")
+
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key='answer',
         return_messages=True
     )
 
-    retriever=vector_db.as_retriever()
+    retriever = vector_db.as_retriever()
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
@@ -78,27 +157,46 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
         return_source_documents=True,
         verbose=False,
     )
+
+    progress(1.0, desc="LLM chain initialized")
     return qa_chain
 
 # Initialize database
-# @spaces.GPU(duration=60)
 def initialize_database(list_file_obj, progress=gr.Progress()):
     # Create a list of documents (when valid)
     list_file_path = [x.name for x in list_file_obj if x is not None]
+
+    if not list_file_path:
+        return None, "No valid files uploaded. Please upload at least one file."
+
     # Load document and create splits
-    doc_splits = load_doc(list_file_path)
+    doc_splits = load_doc(list_file_path, progress)
+
+    if not doc_splits:
+        return None, "Could not extract any text from the uploaded files."
+
     # Create or load vector database
-    vector_db = create_db(doc_splits)
-    return vector_db, "Database created!"
+    vector_db = create_db(doc_splits, progress)
+
+    # Count documents by type
+    file_types = {}
+    for path in list_file_path:
+        ext = Path(path).suffix.lower()
+        file_types[ext] = file_types.get(ext, 0) + 1
+
+    file_type_summary = ", ".join([f"{count} {ext}" for ext, count in file_types.items()])
+
+    return vector_db, f"Database created with {len(doc_splits)} chunks from {len(list_file_path)} files ({file_type_summary})!"
 
 # Initialize LLM
-# @spaces.GPU(duration=60)
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    # print("llm_option",llm_option)
+    if vector_db is None:
+        return None, "Please create a vector database first!"
+
     llm_name = list_llm[llm_option]
-    print("llm_name: ",llm_name)
+    print("llm_name: ", llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-    return qa_chain, "QA chain initialized. Chatbot is ready!"
+    return qa_chain, f"QA chain initialized with {llm_name}. Chatbot is ready!"
 
 
 def format_chat_history(message, chat_history):
@@ -108,73 +206,166 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
-# @spaces.GPU(duration=60)
 def conversation(qa_chain, message, history):
+    if qa_chain is None:
+        return None, gr.update(value=""), history, "Please initialize the chatbot first!", 0, "", 0, "", 0
+
     formatted_chat_history = format_chat_history(message, history)
     # Generate response using QA chain
     response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
     if response_answer.find("Helpful Answer:") != -1:
         response_answer = response_answer.split("Helpful Answer:")[-1]
+
     response_sources = response["source_documents"]
-    response_source1 = response_sources[0].page_content.strip()
-    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
-    # Langchain sources are zero-based
-    response_source1_page = response_sources[0].metadata["page"] + 1
-    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
+
+    # Handle source documents
+    source_contents = ["", "", ""]
+    source_pages = [0, 0, 0]
+
+    for i, source in enumerate(response_sources[:3]):
+        source_contents[i] = source.page_content.strip()
+        # Check if the metadata contains a page number
+        if "page" in source.metadata:
+            source_pages[i] = source.metadata["page"] + 1
+        elif "source" in source.metadata:
+            source_pages[i] = 1
+            source_contents[i] = f"From: {source.metadata['source']}\n{source_contents[i]}"
+
     # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
+    return qa_chain, gr.update(value=""), new_history, source_contents[0], source_pages[0], source_contents[1], source_pages[1], source_contents[2], source_pages[2]
+
+
+def get_file_icon(file_path):
+    """Return an appropriate emoji icon based on file extension"""
+    ext = Path(file_path).suffix.lower()
+    icons = {
+        '.pdf': '📄',
+        '.txt': '📝',
+        '.md': '📋',
+        '.py': '🐍',
+        '.js': '⚙️',
+        '.json': '📊',
+        '.jsonl': '📊',
+        '.csv': '📈',
+        '.html': '🌐',
+        '.css': '🎨',
+    }
+    return icons.get(ext, '📁')
 
+
+def display_file_list(file_obj):
+    if not file_obj:
+        return "No files uploaded yet"
+
+    file_list = [f"{get_file_icon(x.name)} {Path(x.name).name}" for x in file_obj if x is not None]
+    return "\n".join(file_list)
 
-def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-    return list_file_path
 
 
 def demo():
-    # with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
-    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue = "sky")) as demo:
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue", neutral_hue="sky")) as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
-        gr.HTML("<center><h1>RAG PDF chatbot</h1><center>")
-        gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. The app is hosted on Hugging Face Hub for the sole purpose of demonstration. \
-        <b>Please do not upload confidential documents.</b>
+
+        gr.HTML("<center><h1>📚 Enhanced RAG Chatbot</h1></center>")
+        gr.Markdown("""<b>Query your documents!</b> This enhanced AI agent performs retrieval augmented generation (RAG) on various document types
+        including PDFs, text files, markdown, code files, and structured data (CSV, JSON, JSONL). <b>Please do not upload confidential documents.</b>
         """)
+
         with gr.Row():
-            with gr.Column(scale = 86):
-                gr.Markdown("<b>Step 1 - Upload PDF documents and Initialize RAG pipeline</b>")
+            with gr.Column(scale=86):
+                gr.Markdown("<b>Step 1 - Upload Documents and Initialize RAG Pipeline</b>")
                 with gr.Row():
-                    document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
+                    with gr.Column(scale=7):
+                        document = gr.Files(
+                            height=300,
+                            file_count="multiple",
+                            file_types=[".pdf", ".txt", ".md", ".py", ".js", ".json", ".jsonl", ".csv", ".html", ".css"],
+                            interactive=True,
+                            label="Upload Documents"
+                        )
+                    with gr.Column(scale=3):
+                        file_list = gr.Textbox(
+                            label="Uploaded Files",
+                            value="No files uploaded yet",
+                            interactive=False,
+                            lines=12
+                        )
+                document.upload(
+                    display_file_list,
+                    inputs=[document],
+                    outputs=[file_list]
+                )
+
                 with gr.Row():
-                    db_btn = gr.Button("Create vector database")
+                    db_btn = gr.Button("Create Vector Database", variant="primary")
+
                 with gr.Row():
-                    db_progress = gr.Textbox(value="Not initialized", show_label=False) # label="Vector database status",
-                gr.Markdown("<style>body { font-size: 16px; }</style><b>Select Large Language Model (LLM) and input parameters</b>")
+                    db_progress = gr.Textbox(
+                        value="Not initialized",
+                        show_label=False,
+                        container=True
+                    )
+
+                gr.Markdown("<b>Step 2 - Select LLM and Parameters</b>")
                 with gr.Row():
-                    llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value = list_llm_simple[0], type="index") # info="Select LLM", show_label=False
+                    llm_btn = gr.Radio(
+                        list_llm_simple,
+                        label="Available LLMs",
+                        value=list_llm_simple[0],
+                        type="index"
+                    )
+
                 with gr.Row():
-                    with gr.Accordion("LLM input parameters", open=False):
+                    with gr.Accordion("LLM Parameters", open=False):
                         with gr.Row():
-                            slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
+                            slider_temperature = gr.Slider(
+                                minimum=0.01,
+                                maximum=1.0,
+                                value=0.5,
+                                step=0.1,
+                                label="Temperature",
+                                info="Controls randomness in generation",
+                                interactive=True
+                            )
                         with gr.Row():
-                            slider_maxtokens = gr.Slider(minimum = 128, maximum = 9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated",interactive=True)
+                            slider_maxtokens = gr.Slider(
+                                minimum=128,
+                                maximum=9192,
+                                value=4096,
+                                step=128,
+                                label="Max New Tokens",
+                                info="Maximum tokens to generate",
+                                interactive=True
+                            )
                         with gr.Row():
-                            slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
+                            slider_topk = gr.Slider(
+                                minimum=1,
+                                maximum=10,
+                                value=3,
+                                step=1,
+                                label="Top-k",
+                                info="Number of tokens to consider",
+                                interactive=True
+                            )
+
                 with gr.Row():
-                    qachain_btn = gr.Button("Initialize Question Answering Chatbot")
+                    qachain_btn = gr.Button("Initialize Chatbot", variant="primary")
+
                 with gr.Row():
-                    llm_progress = gr.Textbox(value="Not initialized", show_label=False) # label="Chatbot status",
+                    llm_progress = gr.Textbox(
+                        value="Not initialized",
+                        show_label=False,
+                        container=True
+                    )
 
-            with gr.Column(scale = 200):
-                gr.Markdown("<b>Step 2 - Chat with your Document</b>")
+            with gr.Column(scale=200):
+                gr.Markdown("<b>Step 3 - Chat with Your Documents</b>")
                 chatbot = gr.Chatbot(height=505)
-                with gr.Accordion("Relevent context from the source document", open=False):
+
+                with gr.Accordion("Relevant Context from Documents", open=False):
                     with gr.Row():
                         doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                        source1_page = gr.Number(label="Page", scale=1)
@@ -184,36 +375,58 @@ def demo():
                     with gr.Row():
                         doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                         source3_page = gr.Number(label="Page", scale=1)
+
                 with gr.Row():
-                    msg = gr.Textbox(placeholder="Ask a question", container=True)
+                    msg = gr.Textbox(
+                        placeholder="Ask a question about your documents...",
+                        container=True,
+                        lines=2
+                    )
+
                 with gr.Row():
-                    submit_btn = gr.Button("Submit")
+                    submit_btn = gr.Button("Submit", variant="primary")
                     clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
 
         # Preprocessing events
-        db_btn.click(initialize_database, \
-            inputs=[document], \
-            outputs=[vector_db, db_progress])
-        qachain_btn.click(initialize_LLM, \
-            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
-            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
-            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
+        db_btn.click(
+            initialize_database,
+            inputs=[document],
+            outputs=[vector_db, db_progress]
+        )
+
+        qachain_btn.click(
+            initialize_LLM,
+            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
+            outputs=[qa_chain, llm_progress]
+        ).then(
+            lambda:[None,"",0,"",0,"",0],
+            inputs=None,
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False
+        )
 
         # Chatbot events
-        msg.submit(conversation, \
-            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-        submit_btn.click(conversation, \
-            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
-            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
+        msg.submit(
+            conversation,
+            inputs=[qa_chain, msg, chatbot],
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False
+        )
+
+        submit_btn.click(
+            conversation,
+            inputs=[qa_chain, msg, chatbot],
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False
+        )
+
+        clear_btn.click(
+            lambda:[None,"",0,"",0,"",0],
+            inputs=None,
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False
+        )
+
     demo.queue().launch(debug=True)
 
 
 
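For reviewers who want to exercise the new ingestion path outside the Gradio UI, here is a minimal standalone sketch of what this diff wires up (multi-format load, recursive split, FAISS index). It mirrors load_doc/split_documents/create_db without the progress plumbing; the file names are hypothetical, and it assumes langchain, langchain-community, faiss-cpu, and sentence-transformers are installed.

# Sketch only -- not part of the diff. Mirrors the app's ingestion pipeline.
from pathlib import Path

from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter

docs = []
for file_path in ["sample.pdf", "notes.md"]:  # hypothetical input files
    ext = Path(file_path).suffix.lower()
    loader = PyPDFLoader(file_path) if ext == ".pdf" else TextLoader(file_path)
    docs.extend(loader.load())

# Same chunking parameters as the app: chunk_size=1024, chunk_overlap=64
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
splits = splitter.split_documents(docs)

# create_db() builds the same kind of index; HuggingFaceEmbeddings() falls back
# to its default sentence-transformers model, as in the app
vectordb = FAISS.from_documents(splits, HuggingFaceEmbeddings())
for doc in vectordb.similarity_search("What are these documents about?", k=3):
    print(doc.metadata.get("page", "n/a"), doc.page_content[:80])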
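Likewise, a sketch of the query path that initialize_llmchain() and conversation() drive. The HuggingFaceEndpoint kwargs are the ones this diff passes; the import paths and the headless usage are assumptions (the app's own imports sit outside the shown hunks), and HF_TOKEN must be set in the environment.

# Sketch only -- same chain wiring as initialize_llmchain(), run headlessly.
# Import paths are assumptions; `vectordb` comes from the ingestion sketch above.
import os

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    temperature=0.5,
    max_new_tokens=4096,
    top_k=3,
)
memory = ConversationBufferMemory(
    memory_key="chat_history",
    output_key="answer",  # tells memory which of the two chain outputs to store
    return_messages=True,
)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=vectordb.as_retriever(),
    memory=memory,
    return_source_documents=True,
)

response = qa_chain.invoke({"question": "Summarize the uploaded files.", "chat_history": []})
print(response["answer"])
for src in response["source_documents"][:3]:
    print(src.metadata.get("page", src.metadata.get("source", "?")))

As in the app, output_key="answer" on the memory is what keeps return_source_documents=True from leaving the memory ambiguous about which of the chain's two outputs to record.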