File size: 7,777 Bytes
3ec9224
5be8df6
 
 
 
 
 
 
1ef8d7c
 
 
878c0a1
 
5be8df6
 
 
878c0a1
 
5be8df6
 
 
1ef8d7c
5be8df6
1ef8d7c
5be8df6
 
 
1ef8d7c
 
5be8df6
 
 
 
 
878c0a1
5be8df6
 
 
 
 
878c0a1
88fa380
878c0a1
 
eb94a8f
878c0a1
 
5be8df6
 
 
 
878c0a1
5be8df6
9733941
5be8df6
 
00bd139
5be8df6
 
 
1ef8d7c
5be8df6
 
 
1ef8d7c
5be8df6
1ef8d7c
5be8df6
 
878c0a1
00bd139
 
5be8df6
 
878c0a1
5be8df6
 
00bd139
5be8df6
 
9733941
 
 
 
 
 
 
00bd139
5be8df6
 
878c0a1
5be8df6
 
 
 
3ca2785
00bd139
1ef8d7c
878c0a1
 
 
 
 
 
a25f0eb
878c0a1
 
 
 
5be8df6
878c0a1
 
 
 
 
 
 
 
 
 
 
14155e5
878c0a1
 
 
 
 
 
 
 
 
 
 
 
5be8df6
323ccbe
5be8df6
 
878c0a1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import re
from pathlib import Path

import chromadb
import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Hugging Face Hub repo ids of the models selectable in the UI.
llm_names = ["mistralai/Mixtral-8x7B-Instruct-v0.1"]
# Display labels for the model radio: each repo id without its owner prefix.
# Index i here corresponds to llm_names[i] (the radio returns an index).
llm_names_simple = list(map(os.path.basename, llm_names))

def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: paths of the PDF files to load.
        chunk_size: target chunk size passed to RecursiveCharacterTextSplitter.
        chunk_overlap: overlap between consecutive chunks (the splitter
            raises if this exceeds ``chunk_size``).

    Returns:
        List of split Documents, in file order then page order.
    """
    pages = []
    for file_path in list_file_path:
        # PyPDFLoader.load() returns a *list* of per-page Documents; flatten so
        # split_documents receives Documents rather than nested lists.
        pages.extend(PyPDFLoader(file_path).load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages)

def create_db(splits, collection_name):
    """Embed document chunks into a new Chroma collection.

    Uses ``HuggingFaceEmbeddings()`` with its default model and a fresh
    ``chromadb.EphemeralClient()`` (nothing is written to disk by this call).

    Args:
        splits: Documents to embed and store.
        collection_name: name of the Chroma collection to create/use.

    Returns:
        The populated Chroma vector store.
    """
    return Chroma.from_documents(
        documents=splits,
        embedding=HuggingFaceEmbeddings(),
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )

def load_db():
    """Open a Chroma vector store with default HuggingFace embeddings.

    No client, collection name or persistence directory is passed, so Chroma's
    own defaults apply.
    NOTE(review): not referenced anywhere in the visible file.
    """
    return Chroma(embedding_function=HuggingFaceEmbeddings())

def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a conversational retrieval QA chain over ``vector_db``.

    Args:
        llm_model: Hugging Face Hub repo id of the model to query.
        temperature: sampling temperature; values below 0.01 are raised to 0.01.
        max_tokens: sent to the model as ``max_new_tokens``.
        top_k: sent to the model as ``top_k``.
        vector_db: vector store whose ``as_retriever()`` supplies context.
        progress: Gradio progress tracker used for status messages.

    Returns:
        A ConversationalRetrievalChain ("stuff" chain type) with buffer memory
        keyed ``chat_history`` that records the ``answer`` output and returns
        source documents alongside each answer.
    """
    progress(0.1, desc="Initializing HF tokenizer...")
    progress(0.5, desc="Initializing HF Hub...")
    # NOTE(review): the hosted text-generation backend appears to reject
    # temperature <= 0 ("must be strictly positive"), while the UI slider
    # allows 0.0; clamp to a tiny positive value (near-greedy sampling).
    model_kwargs = {
        "temperature": max(temperature, 0.01),
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        # NOTE(review): load_in_8bit is a local model-loading option; unclear
        # whether the hosted Hub endpoint honours it. Kept as before.
        model_kwargs["load_in_8bit"] = True
    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=model_kwargs)
    progress(0.75, desc="Defining buffer memory...")
    # output_key='answer' is needed because the chain also returns
    # source_documents, so memory must know which output to store.
    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
    )
    progress(0.9, desc="Done!")
    return qa_chain

def _sanitize_collection_name(name):
    """Coerce a file stem into a valid Chroma collection name.

    Chroma requires 3-63 characters, starting and ending with an alphanumeric
    character, otherwise only alphanumerics, underscores or hyphens. Runs of
    any other character (spaces, dots, non-ASCII) become a single hyphen.
    """
    name = re.sub(r"[^A-Za-z0-9_-]+", "-", name)[:63].strip("-_")
    if len(name) < 3:
        # Pad short/empty names; the suffix keeps the last char alphanumeric.
        name = f"{name or 'pdf'}-db"
    return name


def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    """Load uploaded PDFs, split them and index them in a new vector store.

    Args:
        list_file_obj: uploads from ``gr.Files`` — path strings or file
            wrappers exposing ``.name``; may be ``None`` when nothing is uploaded.
        chunk_size: text-splitter chunk size.
        chunk_overlap: text-splitter chunk overlap.
        progress: Gradio progress tracker used for status messages.

    Returns:
        ``(vector_db, collection_name, "Complete!")``; the collection name is
        derived from the first file's stem, sanitized for Chroma.

    Raises:
        gr.Error: if no file was uploaded.
    """
    list_file_path = [
        x if isinstance(x, str) else x.name
        for x in (list_file_obj or [])
        if x is not None
    ]
    if not list_file_path:
        raise gr.Error("Please upload at least one PDF document.")
    collection_name = _sanitize_collection_name(Path(list_file_path[0]).stem)
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"

def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Gradio handler: build the QA chain for the selected model.

    Args:
        llm_option: index into ``llm_names`` (the model radio uses type="index").
        llm_temperature: sampling temperature.
        max_tokens: maximum new tokens to generate.
        top_k: top-k sampling parameter.
        vector_db: vector store produced by Step 1 (``None`` if not built yet).
        progress: Gradio progress tracker forwarded to the chain builder.

    Returns:
        ``(qa_chain, "Complete!")``.

    Raises:
        gr.Error: if the vector database has not been generated yet.
    """
    if vector_db is None:
        # Without this check the failure surfaces as an opaque AttributeError
        # from vector_db.as_retriever() inside initialize_llmchain.
        raise gr.Error("Generate the vector database (Step 1) before initializing the QA chain.")
    llm_name = llm_names[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

def format_chat_history(message, chat_history):
    """Render each chat turn as a two-line "User: ..." / "Assistant: ..." string.

    Args:
        message: the current user message; not used, kept for call compatibility.
        chat_history: iterable of ``(user_message, bot_message)`` pairs.

    Returns:
        One formatted string per turn, in the original order.
    """
    rendered = []
    for said_by_user, said_by_bot in chat_history:
        rendered.append(f"User: {said_by_user}\nAssistant: {said_by_bot}")
    return rendered

def conversation(qa_chain, message, history):
    """Gradio handler: answer ``message`` with the QA chain and refresh the UI.

    Args:
        qa_chain: chain built in Step 2 (``None`` if not initialized yet).
        message: the user's question.
        history: list of ``(user, assistant)`` pairs shown in the chatbot;
            ``None`` is treated as empty.

    Returns:
        ``(qa_chain, update clearing the textbox, new_history, source1_text,
        source1_page, source2_text, source2_page)``. Pages are 1-based; a missing
        source (fewer than two retrieved) or missing page yields ``""`` / ``0``,
        the same values the Clear button resets to.

    Raises:
        gr.Error: if the QA chain has not been initialized yet.
    """
    if qa_chain is None:
        raise gr.Error("Initialize the question-answering chain (Step 2) first.")
    history = history or []
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]

    # The retriever may return fewer than two documents; pad instead of
    # indexing blindly.
    references = []
    for doc in (response.get("source_documents") or [])[:2]:
        page = doc.metadata.get("page")
        references.append((doc.page_content.strip(), page + 1 if page is not None else 0))
    while len(references) < 2:
        references.append(("", 0))
    (response_source1, response_source1_page), (response_source2, response_source2_page) = references

    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page

def upload_file(file_obj):
    """Return the filesystem paths of uploaded files.

    Args:
        file_obj: list of uploads as delivered by ``gr.Files`` — path strings or
            file wrappers exposing ``.name``; ``None`` when nothing is uploaded.

    Returns:
        List of path strings, one per non-``None`` upload, in order.
    """
    if not file_obj:
        return []
    # Previously `file_obj.name for _ in file_obj` read `.name` from the list
    # itself instead of from each uploaded file.
    return [f if isinstance(f, str) else f.name for f in file_obj if f is not None]

def demo():
    """Build the three-step ChatPDF Gradio app and launch it.

    Step 1 uploads PDFs and builds a Chroma vector store, Step 2 builds the
    conversational retrieval chain, Step 3 chats with it and shows the text and
    page of the top two source passages. ``launch(debug=True)`` keeps the call
    running while the server is up.
    """
    with gr.Blocks(theme="base") as app:
        # Per-session values passed between the steps' handlers.
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown("""<center><h2>ChatPDF</h2></center>""")

        with gr.Tab("Step 1 - Selezione PDF"):
            # Gradio expects extensions with a leading dot (".pdf").
            document = gr.Files(height=100, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
            # Only ChromaDB is offered and the selection is never read; this
            # radio is informational.
            gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
            with gr.Accordion("Advanced options - Document text splitter", open=False):
                slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
                slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
            db_progress = gr.Textbox(label="Vector database initialization", value="None")
            # gr.Radio has no click event, so building the DB needs a real button.
            db_btn = gr.Button("Generate vector database...")

        with gr.Tab("Step 2 - Inizializzazione QA"):
            llm_btn = gr.Radio(llm_names_simple, label="LLM models", value=llm_names_simple[0], type="index", info="Choose your LLM model")
            with gr.Accordion("Advanced options - LLM model", open=False):
                slider_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
            llm_progress = gr.Textbox(value="None", label="QA chain initialization")
            qachain_btn = gr.Button("Initialize question-answering chain...")

        with gr.Tab("Step 3 - Conversazione con Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced - Document references", open=True):
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            msg = gr.Textbox(placeholder="Type message", container=True)
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])

        # Events are wired only after all tabs exist: the QA-chain reset targets
        # Step 3 components, which were previously referenced before definition.
        reset_outputs = [chatbot, doc_source1, source1_page, doc_source2, source2_page]
        chat_inputs = [qa_chain, msg, chatbot]
        chat_outputs = [qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page]

        db_btn.click(
            initialize_database,
            inputs=[document, slider_chunk_size, slider_chunk_overlap],
            outputs=[vector_db, collection_name, db_progress],
        )
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress],
        ).then(lambda: [None, "", 0, "", 0], inputs=None, outputs=reset_outputs, queue=False)

        msg.submit(conversation, inputs=chat_inputs, outputs=chat_outputs, queue=False)
        submit_btn.click(conversation, inputs=chat_inputs, outputs=chat_outputs, queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0], inputs=None, outputs=reset_outputs, queue=False)

    app.queue().launch(debug=True)

if __name__ == "__main__":
    demo()