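"""ChatPDF: a Gradio app for retrieval-augmented question answering over PDF documents.

Pipeline: uploaded PDFs are split into overlapping chunks, embedded into a Chroma
vector database, and queried through a LangChain ConversationalRetrievalChain
backed by a Hugging Face Hub LLM (Mixtral-8x7B-Instruct by default).
"""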
import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
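# NOTE: the HuggingFaceHub LLM wrapper reads the HUGGINGFACEHUB_API_TOKEN
# environment variable when no token is passed explicitly, so it must be set
# (e.g. as a Space secret) before the app can query the model.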
llm_names = ["mistralai/Mixtral-8x7B-Instruct-v0.1"]
llm_names_simple = [os.path.basename(llm) for llm in llm_names]
def load_doc(list_file_path, chunk_size, chunk_overlap):
    # Load the PDF documents and split them into overlapping text chunks
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
def create_db(splits, collection_name):
    # Build an in-memory Chroma vector database from the document chunks
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Initializing HF tokenizer...")
    progress(0.5, desc="Initializing HF Hub...")
    model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        model_kwargs["load_in_8bit"] = True
    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=model_kwargs)
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
    )
    progress(0.9, desc="Done!")
    return qa_chain
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = Path(list_file_path[0]).stem
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = llm_names[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"
def format_chat_history(message, chat_history):
    formatted_chat_history = [f"User: {user_message}\nAssistant: {bot_message}" for user_message, bot_message in chat_history]
    return formatted_chat_history
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    # Page metadata is zero-based; add 1 for display
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
def upload_file(file_obj):
    list_file_path = [x.name for x in file_obj]
    return list_file_path
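# Gradio UI: Step 1 builds the vector database from the uploaded PDFs,
# Step 2 initializes the question-answering chain, and Step 3 hosts the
# chatbot together with its source-document references.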
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.Markdown("""<center><h2>ChatPDF</h2></center>""")
        with gr.Tab("Step 1 - PDF selection"):
            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
            db_choice = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
            with gr.Accordion("Advanced options - Document text splitter", open=False):
                slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
                slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
            db_progress = gr.Textbox(label="Vector database initialization", value="None")
            # Button to generate the vector database from the uploaded PDFs
            db_btn = gr.Button("Generate vector database")
        with gr.Tab("Step 2 - QA chain initialization"):
            llm_btn = gr.Radio(llm_names_simple, label="LLM models", value=llm_names_simple[0], type="index", info="Choose your LLM model")
            with gr.Accordion("Advanced options - LLM model", open=False):
                slider_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
            llm_progress = gr.Textbox(value="None", label="QA chain initialization")
            qachain_btn = gr.Button("Initialize question-answering chain...")
        with gr.Tab("Step 3 - Conversation with the chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced - Document references", open=True):
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            msg = gr.Textbox(placeholder="Type message", container=True)
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])
        # Wire up the events after all components exist, so the outputs of one
        # step can reference components defined in a later tab.
        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], queue=False)
    demo.queue().launch(debug=True)
if __name__ == "__main__":
    demo()