MatteoScript committed (verified)
Commit 1a34146 · Parent: 878c0a1

Update app.py

Files changed (1): app.py (+117 -50)
app.py CHANGED
@@ -5,20 +5,35 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
 from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.llms import HuggingFacePipeline
+from langchain.chains import ConversationChain
+from langchain.memory import ConversationBufferMemory
 from langchain.llms import HuggingFaceHub
 from pathlib import Path
 import chromadb
+from transformers import AutoTokenizer
+import transformers
+import torch
+import tqdm
+import accelerate
 
-llm_names = ["mistralai/Mixtral-8x7B-Instruct-v0.1"]
-llm_names_simple = [os.path.basename(llm) for llm in llm_names]
+llm_name0 = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+list_llm = [llm_name0]
+list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
+# Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
-    pages = [loader.load() for loader in loaders]
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+    pages = []
+    for loader in loaders:
+        pages.extend(loader.load())
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size = chunk_size,
+        chunk_overlap = chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
+# Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.EphemeralClient()
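A side note on the reworked load_doc: the old list comprehension produced one list of Documents per loader, so split_documents() received a nested list; the new loop flattens all pages into a single list. A minimal sketch of the difference, with placeholder file names:

from langchain.document_loaders import PyPDFLoader

# "a.pdf" and "b.pdf" are placeholder paths, for illustration only.
loaders = [PyPDFLoader(p) for p in ["a.pdf", "b.pdf"]]

# Old: nested, [[Document, ...], [Document, ...]] -- not what
# RecursiveCharacterTextSplitter.split_documents expects.
pages_nested = [loader.load() for loader in loaders]

# New: flat, [Document, Document, ...].
pages_flat = []
for loader in loaders:
    pages_flat.extend(loader.load())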
@@ -30,26 +45,35 @@ def create_db(splits, collection_name):
     )
     return vectordb
 
+# Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(embedding_function=embedding)
+    vectordb = Chroma(
+        embedding_function=embedding)
     return vectordb
 
+# Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
     progress(0.5, desc="Initializing HF Hub...")
-    model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        model_kwargs["load_in_8bit"] = True
-    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=model_kwargs)
+        llm = HuggingFaceHub(
+            repo_id=llm_model,
+            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+        )
     progress(0.75, desc="Defining buffer memory...")
-    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
-    retriever = vector_db.as_retriever()
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key='answer',
+        return_messages=True
+    )
+
+    retriever=vector_db.as_retriever()
     progress(0.8, desc="Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-    chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
     )
@@ -67,14 +91,18 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     return vector_db, collection_name, "Complete!"
 
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    llm_name = llm_names[llm_option]
+    llm_name = list_llm[llm_option]
+    print("llm_name: ",llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
 def format_chat_history(message, chat_history):
-    formatted_chat_history = [f"User: {user_message}\nAssistant: {bot_message}" for user_message, bot_message in chat_history]
+    formatted_chat_history = []
+    for user_message, bot_message in chat_history:
+        formatted_chat_history.append(f"User: {user_message}")
+        formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
-
+
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
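For clarity, the rewritten format_chat_history emits two list entries per turn ("User: ..." and "Assistant: ...") instead of one combined string. A worked example with a hypothetical two-turn history:

history = [("Hi", "Hello!"), ("What does the PDF cover?", "A rental contract.")]
format_chat_history("next question", history)
# -> ["User: Hi",
#     "Assistant: Hello!",
#     "User: What does the PDF cover?",
#     "Assistant: A rental contract."]
# (the `message` argument is currently unused by the function)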
@@ -86,9 +114,12 @@ def conversation(qa_chain, message, history):
     response_source2_page = response_sources[1].metadata["page"] + 1
     new_history = history + [(message, response_answer)]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
-
+
 def upload_file(file_obj):
-    list_file_path = [file_obj.name for _ in file_obj]
+    list_file_path = []
+    for idx, file in enumerate(file_obj):
+        file_path = file_obj.name
+        list_file_path.append(file_path)
     return list_file_path
 
 def demo():
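One caveat on the new upload_file: the loop variable `file` is unused and `file_obj.name` is read on every iteration, so all entries get the same path. If file_obj is the list of uploads from gr.Files, a per-file version would presumably read each file's own path:

def upload_file(file_obj):
    list_file_path = []
    for idx, file in enumerate(file_obj):
        file_path = file.name  # assumption: each upload exposes its own .name
        list_file_path.append(file_path)
    return list_file_path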
@@ -96,43 +127,79 @@ def demo():
     vector_db = gr.State()
     qa_chain = gr.State()
     collection_name = gr.State()
-
-    gr.Markdown("""<center><h2>ChatPDF</center></h2>""")
-
-    with gr.Tab("Step 1 - Selezione PDF"):
-        document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-        db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
+
+    gr.Markdown(
+    """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>""")
+    with gr.Tab("Step 1 - Document pre-processing"):
+        with gr.Row():
+            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+        with gr.Row():
+            db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
         with gr.Accordion("Advanced options - Document text splitter", open=False):
-            slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-            slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-        db_progress = gr.Textbox(label="Vector database initialization", value="None")
-        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
-
-    with gr.Tab("Step 2 - Inizializzazione QA"):
-        llm_btn = gr.Radio(llm_names_simple, label="LLM models", value=llm_names_simple[0], type="index", info="Choose your LLM model")
+            with gr.Row():
+                slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+            with gr.Row():
+                slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+        with gr.Row():
+            db_progress = gr.Textbox(label="Vector database initialization", value="None")
+        with gr.Row():
+            db_btn = gr.Button("Generate vector database...")
+
+    with gr.Tab("Step 2 - QA chain initialization"):
+        with gr.Row():
+            llm_btn = gr.Radio(list_llm_simple, \
+                label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
         with gr.Accordion("Advanced options - LLM model", open=False):
-            slider_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-            slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-            slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-        llm_progress = gr.Textbox(value="None", label="QA chain initialization")
-        qachain_btn = gr.Button("Initialize question-answering chain...")
-        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
-
-    with gr.Tab("Step 3 - Conversazione con Chatbot"):
+            with gr.Row():
+                slider_temperature = gr.Slider(minimum = 0.0, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+            with gr.Row():
+                slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+            with gr.Row():
+                slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+        with gr.Row():
+            llm_progress = gr.Textbox(value="None",label="QA chain initialization")
+        with gr.Row():
+            qachain_btn = gr.Button("Initialize question-answering chain...")
+
+    with gr.Tab("Step 3 - Conversation with chatbot"):
         chatbot = gr.Chatbot(height=300)
-        with gr.Accordion("Advanced - Document references", open=True):
-            doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-            source1_page = gr.Number(label="Page", scale=1)
-            doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-            source2_page = gr.Number(label="Page", scale=1)
-        msg = gr.Textbox(placeholder="Type message", container=True)
-        submit_btn = gr.Button("Submit")
-        clear_btn = gr.ClearButton([msg, chatbot])
-
-        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
-        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
-        clear_btn.click(lambda: [None, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
-
+        with gr.Accordion("Advanced - Document references", open=False):
+            with gr.Row():
+                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                source1_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                source2_page = gr.Number(label="Page", scale=1)
+        with gr.Row():
+            msg = gr.Textbox(placeholder="Type message", container=True)
+        with gr.Row():
+            submit_btn = gr.Button("Submit")
+            clear_btn = gr.ClearButton([msg, chatbot])
+
+    # Preprocessing events
+    db_btn.click(initialize_database, \
+        inputs=[document, slider_chunk_size, slider_chunk_overlap], \
+        outputs=[vector_db, collection_name, db_progress])
+    qachain_btn.click(initialize_LLM, \
+        inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
+        outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0], \
+        inputs=None, \
+        outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+        queue=False)
+
+    # Chatbot events
+    msg.submit(conversation, \
+        inputs=[qa_chain, msg, chatbot], \
+        outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+        queue=False)
+    submit_btn.click(conversation, \
+        inputs=[qa_chain, msg, chatbot], \
+        outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+        queue=False)
+    clear_btn.click(lambda:[None,"",0,"",0], \
+        inputs=None, \
+        outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+        queue=False)
     demo.queue().launch(debug=True)
 
 if __name__ == "__main__":
 
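Taken together, the refactored functions form a small retrieval-augmented pipeline that can also be exercised outside the Gradio UI. A minimal sketch, assuming "sample.pdf" is a placeholder path, HUGGINGFACEHUB_API_TOKEN is set in the environment (HuggingFaceHub requires it), and the parameter values mirror the UI defaults:

# Illustrative only; run from the same module so load_doc, create_db
# and initialize_llmchain are in scope.
splits = load_doc(["sample.pdf"], chunk_size=600, chunk_overlap=40)  # chunk the PDF
vector_db = create_db(splits, collection_name="sample")              # embed chunks into Chroma
qa_chain = initialize_llmchain(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # takes the HuggingFaceHub branch above
    temperature=0.7, max_tokens=1024, top_k=3,
    vector_db=vector_db,
)  # the gr.Progress() default is only meaningful inside the running app
response = qa_chain({"question": "What is this document about?", "chat_history": []})
print(response["answer"])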