MatteoScript committed
Commit 209f685 · verified · 1 Parent(s): 1a34146

Update app.py

Files changed (1): app.py +56 -3
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 import os
+
 from langchain.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
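The import block implies the Space's dependency set; a hedged sketch of the matching requirements follows (package names inferred from the imports, exact pins unknown):

    # requirements.txt sketch (inferred from the imports, not part of this commit)
    # gradio
    # langchain
    # chromadb
    # pypdf                  # backs PyPDFLoader
    # transformers
    # torch
    # tqdm
    # accelerate
    # sentence-transformers  # backs HuggingFaceEmbeddings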
@@ -9,30 +10,39 @@ from langchain.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain.llms import HuggingFaceHub
+
 from pathlib import Path
 import chromadb
+
 from transformers import AutoTokenizer
 import transformers
 import torch
 import tqdm
 import accelerate
 
+
+
 llm_name0 = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-list_llm = [llm_name0]
+list_llm = [llm_name0, llm_name1, llm_name2, llm_name3, llm_name4, llm_name5, llm_name6, llm_name7, llm_name8]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
 # Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
+    # Processing for one document only
+    # loader = PyPDFLoader(file_path)
+    # pages = loader.load()
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
+    # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size = chunk_size,
         chunk_overlap = chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
+
 # Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
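For context on the splitter settings in this hunk, here is a minimal, self-contained sketch using the same legacy langchain import path as this file; create_documents accepts raw strings, so no PDF is needed to see how chunk_size and chunk_overlap behave:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Same splitter as load_doc, fed plain text instead of PDF pages.
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
    docs = splitter.create_documents(["lorem ipsum " * 500])
    # Each chunk is at most 600 characters; consecutive chunks share ~50.
    print(len(docs), len(docs[0].page_content))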
@@ -42,32 +52,38 @@ def create_db(splits, collection_name):
         embedding=embedding,
         client=new_client,
         collection_name=collection_name,
+        # persist_directory=default_persist_directory
     )
     return vectordb
 
+
 # Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
     vectordb = Chroma(
+        # persist_directory=default_persist_directory,
         embedding_function=embedding)
     return vectordb
 
+
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
     progress(0.5, desc="Initializing HF Hub...")
+    # URL: https://github.com/langchain-ai/langchain/issues/6080
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
         llm = HuggingFaceHub(
             repo_id=llm_model,
             model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
         )
+
     progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key='answer',
         return_messages=True
     )
-
+    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
     retriever=vector_db.as_retriever()
     progress(0.8, desc="Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
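The commented-out retriever line kept above documents a tunable variant; a small sketch of the difference, using a throwaway in-memory Chroma store (names here are illustrative, not from the commit):

    from langchain.vectorstores import Chroma
    from langchain.embeddings import HuggingFaceEmbeddings

    # Toy store, only to show the retriever knobs.
    toy_db = Chroma.from_texts(
        ["alpha", "beta", "gamma", "delta", "epsilon"],
        embedding=HuggingFaceEmbeddings())
    # as_retriever() defaults to similarity search with k=4;
    # search_kwargs={'k': 3} caps retrieval at the 3 closest chunks.
    retriever = toy_db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    print(len(retriever.get_relevant_documents("alpha")))  # -> 3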
@@ -75,27 +91,42 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
         retriever=retriever,
         chain_type="stuff",
         memory=memory,
+        # combine_docs_chain_kwargs={"prompt": your_prompt})
         return_source_documents=True,
+        # return_generated_question=True,
+        # verbose=True,
     )
     progress(0.9, desc="Done!")
     return qa_chain
 
+
+# Initialize database
 def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+    # Create list of documents (when valid)
+    #file_path = file_obj.name
     list_file_path = [x.name for x in list_file_obj if x is not None]
     collection_name = Path(list_file_path[0]).stem
+    # print('list_file_path: ', list_file_path)
+    # print('Collection name: ', collection_name)
     progress(0.25, desc="Loading document...")
+    # Load document and create splits
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+    # Create or load Vector database
     progress(0.5, desc="Generating vector database...")
+    # global vector_db
     vector_db = create_db(doc_splits, collection_name)
     progress(0.9, desc="Done!")
     return vector_db, collection_name, "Complete!"
 
+
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    # print("llm_option",llm_option)
     llm_name = list_llm[llm_option]
     print("llm_name: ",llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
+
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
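For readers tracing the data flow: the chain assembled above is what conversation() calls later. A hedged sketch of the call shape, assuming a qa_chain built exactly as in this hunk (with output_key='answer' and return_source_documents=True):

    # The __call__ interface takes the new question plus the formatted history;
    # the result dict then carries the answer and the retrieved source chunks.
    response = qa_chain({"question": "What does chapter 2 cover?", "chat_history": []})
    print(response["answer"])
    for doc in response["source_documents"]:
        print(doc.metadata.get("page"), doc.page_content[:80])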
@@ -103,25 +134,39 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
+
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
+    #print("formatted_chat_history",formatted_chat_history)
+
+    # Generate response using QA chain
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
     response_sources = response["source_documents"]
     response_source1 = response_sources[0].page_content.strip()
     response_source2 = response_sources[1].page_content.strip()
+    # Langchain sources are zero-based
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
+    # print ('chat response: ', response_answer)
+    # print('DB source', response_sources)
+
+    # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
+    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
 
+
 def upload_file(file_obj):
     list_file_path = []
     for idx, file in enumerate(file_obj):
         file_path = file_obj.name
         list_file_path.append(file_path)
+        # print(file_path)
+        # initialize_database(file_path, progress)
     return list_file_path
 
+
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
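The User: half of format_chat_history sits in context lines this hunk does not show; a self-contained sketch of the presumed full logic (the exact User: wording is an assumption):

    def format_chat_history_sketch(message, chat_history):
        # Flatten gradio's [(user, bot), ...] pairs into alternating strings.
        formatted = []
        for user_message, bot_message in chat_history:
            formatted.append(f"User: {user_message}")        # assumed; line not shown in hunk
            formatted.append(f"Assistant: {bot_message}")
        return formatted

    print(format_chat_history_sketch("q3", [("q1", "a1"), ("q2", "a2")]))
    # -> ['User: q1', 'Assistant: a1', 'User: q2', 'Assistant: a2']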
@@ -129,10 +174,16 @@ def demo():
         collection_name = gr.State()
 
         gr.Markdown(
-        """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>""")
+        """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>
+        <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
+        <b>Note:</b> This AI assistant performs retrieval-augmented generation from your PDF documents. \
+        When generating answers, it takes past questions into account (via conversational memory), and includes document references for clarity purposes.</i>
+        <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
+        """)
         with gr.Tab("Step 1 - Document pre-processing"):
             with gr.Row():
                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+                # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
             with gr.Row():
                 db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
             with gr.Accordion("Advanced options - Document text splitter", open=False):
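One detail worth flagging in this hunk: type="index" makes the Radio hand its callback the selected position rather than the label, which is what lets initialize_LLM do list_llm[llm_option]. A minimal sketch:

    import gradio as gr

    with gr.Blocks() as ui:
        choice = gr.Radio(["ChromaDB"], value="ChromaDB", type="index",
                          label="Vector database type")
        shown = gr.Textbox(label="Selected index")
        # With type="index", the handler receives 0 for the first option.
        choice.change(lambda i: str(i), inputs=[choice], outputs=[shown])
    # ui.launch()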
@@ -177,6 +228,7 @@
             clear_btn = gr.ClearButton([msg, chatbot])
 
         # Preprocessing events
+        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
         db_btn.click(initialize_database, \
             inputs=[document, slider_chunk_size, slider_chunk_overlap], \
             outputs=[vector_db, collection_name, db_progress])
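The wiring above is the standard Gradio pattern: handler return values are written positionally into the outputs list, including gr.State holders. A stripped-down sketch of the same pattern (the handler here is hypothetical, not from the commit):

    import gradio as gr

    with gr.Blocks() as ui:
        vector_db = gr.State()
        status = gr.Textbox(label="Initialization")
        go = gr.Button("Generate vector database")
        # Return values map one-to-one onto outputs: state first, textbox second.
        go.click(lambda: ("db-handle", "Complete!"), inputs=[], outputs=[vector_db, status])
    # ui.launch()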
@@ -202,5 +254,6 @@
             queue=False)
     demo.queue().launch(debug=True)
 
+
 if __name__ == "__main__":
     demo()
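Since initialize_llmchain calls out to HuggingFaceHub, the Space needs a Hub token at runtime; langchain's HuggingFaceHub wrapper reads it from the HUGGINGFACEHUB_API_TOKEN environment variable (on Spaces, usually injected as a secret). A quick preflight sketch:

    import os

    # Fail fast if the inference token is missing; without it the
    # HuggingFaceHub call in initialize_llmchain cannot authenticate.
    assert os.environ.get("HUGGINGFACEHUB_API_TOKEN"), "Set HUGGINGFACEHUB_API_TOKEN"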