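"""Enhanced RAG chatbot - Gradio app for a Hugging Face Space.

Loads PDF, text, markdown, code, CSV, and JSON/JSONL documents, splits them
into chunks, indexes them in a FAISS vector store, and answers questions via
a ConversationalRetrievalChain backed by a Hugging Face Inference endpoint.
Requires an HF_TOKEN environment variable with access to the listed models.
"""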
import os
import json
from pathlib import Path

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader, CSVLoader, JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
# import spaces

api_token = os.getenv("HF_TOKEN")

list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
# Load and split documents of various types
def load_doc(list_file_path, progress=gr.Progress()):
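    """Load each uploaded file with a loader chosen by its extension and
    return the combined list of split document chunks. Files that fail to
    parse are logged and skipped."""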
    doc_splits = []
    progress(0, desc="Preparing to load documents")
    total_files = len(list_file_path)
    for i, file_path in enumerate(list_file_path):
        progress((i / total_files) * 0.5, desc=f"Loading {Path(file_path).name}")
        file_ext = Path(file_path).suffix.lower()
        try:
            # PDF documents
            if file_ext == '.pdf':
                loader = PyPDFLoader(file_path)
                pages = loader.load()
                doc_splits.extend(split_documents(pages))
            # Text-based documents (plain text, markdown, source code)
            elif file_ext in ['.txt', '.md', '.py', '.js', '.html', '.css']:
                loader = TextLoader(file_path)
                documents = loader.load()
                doc_splits.extend(split_documents(documents))
            # CSV files
            elif file_ext == '.csv':
                loader = CSVLoader(file_path)
                documents = loader.load()
                doc_splits.extend(split_documents(documents))
            # JSON / JSONL files
            elif file_ext in ['.json', '.jsonl']:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                # Heuristic: if the whole file parses as a single JSON value,
                # treat it as JSON; otherwise fall back to JSONL (one JSON
                # object per line, which also starts with '{').
                is_json = False
                if content.startswith('[') or content.startswith('{'):
                    try:
                        json.loads(content)
                        is_json = True
                    except json.JSONDecodeError:
                        is_json = False
                if is_json:
                    # Regular JSON
                    loader = JSONLoader(
                        file_path=file_path,
                        jq_schema='.',
                        text_content=False
                    )
                    documents = loader.load()
                    doc_splits.extend(split_documents(documents))
                else:
                    # JSONL - process line by line, skipping malformed lines
                    texts = []
                    with open(file_path, 'r', encoding='utf-8') as f:
                        for line in f:
                            if line.strip():
                                try:
                                    texts.append(json.dumps(json.loads(line)))
                                except json.JSONDecodeError:
                                    continue
                    text_splitter = RecursiveCharacterTextSplitter(
                        chunk_size=1024,
                        chunk_overlap=64
                    )
                    doc_splits.extend(text_splitter.create_documents(texts))
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue
        progress(0.5 + ((i + 1) / total_files) * 0.5, desc=f"Processed {Path(file_path).name}")
    return doc_splits
# Helper function to split documents into overlapping chunks
def split_documents(documents):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64
    )
    return text_splitter.split_documents(documents)
# Create vector database
def create_db(splits, progress=gr.Progress()):
    progress(0, desc="Creating vector database")
    embeddings = HuggingFaceEmbeddings()
    # Embed all chunks and index them in an in-memory FAISS store
    vectordb = FAISS.from_documents(
        documents=splits,
        embedding=embeddings
    )
    progress(1.0, desc="Vector database creation complete")
    return vectordb
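# Note: the FAISS index above lives in memory only and is rebuilt on every
# upload. If persistence were wanted, LangChain's FAISS wrapper supports it;
# a sketch (the "faiss_index" path is illustrative):
#   vectordb.save_local("faiss_index")
#   vectordb = FAISS.load_local("faiss_index", embeddings,
#                               allow_dangerous_deserialization=True)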
# Initialize the LangChain conversational retrieval chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0, desc=f"Initializing {llm_model}")
    # Both supported models are served through the HF Inference API, so a
    # single endpoint configuration covers them.
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    progress(0.5, desc="Setting up memory and retriever")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(1.0, desc="LLM chain initialized")
    return qa_chain
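# Note: as_retriever() uses LangChain's default similarity search with k=4
# documents. Since the UI displays at most three references, one could pass
# vector_db.as_retriever(search_kwargs={"k": 3}) instead; the original code
# keeps the default.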
# Initialize database from uploaded files
def initialize_database(list_file_obj, progress=gr.Progress()):
    # Collect the paths of the valid uploads
    list_file_path = [x.name for x in list_file_obj if x is not None]
    if not list_file_path:
        return None, "No valid files uploaded. Please upload at least one file."
    # Load documents and create splits
    doc_splits = load_doc(list_file_path, progress)
    if not doc_splits:
        return None, "Could not extract any text from the uploaded files."
    # Create the vector database
    vector_db = create_db(doc_splits, progress)
    # Count documents by type for the status message
    file_types = {}
    for path in list_file_path:
        ext = Path(path).suffix.lower()
        file_types[ext] = file_types.get(ext, 0) + 1
    file_type_summary = ", ".join(f"{count} {ext}" for ext, count in file_types.items())
    return vector_db, f"Database created with {len(doc_splits)} chunks from {len(list_file_path)} files ({file_type_summary})!"
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    if vector_db is None:
        return None, "Please create a vector database first!"
    llm_name = list_llm[llm_option]
    print("llm_name: ", llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, f"QA chain initialized with {llm_name}. Chatbot is ready!"
def format_chat_history(chat_history):
    # Flatten (user, assistant) tuples into plain strings for the chain
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
def conversation(qa_chain, message, history):
    if qa_chain is None:
        return None, gr.update(value=""), history, "Please initialize the chatbot first!", 0, "", 0, "", 0
    formatted_chat_history = format_chat_history(history)
    # Generate a response with the QA chain
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Some models echo the prompt scaffold; keep only the final answer
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    # Show up to three source snippets with their page numbers
    source_contents = ["", "", ""]
    source_pages = [0, 0, 0]
    for i, source in enumerate(response_sources[:3]):
        source_contents[i] = source.page_content.strip()
        # PDF loaders store 0-based page numbers in the metadata
        if "page" in source.metadata:
            source_pages[i] = source.metadata["page"] + 1
        elif "source" in source.metadata:
            source_pages[i] = 1
            source_contents[i] = f"From: {source.metadata['source']}\n{source_contents[i]}"
    # Append the new exchange to the chat history
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, source_contents[0], source_pages[0], source_contents[1], source_pages[1], source_contents[2], source_pages[2]
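# Note: history arrives as (user, assistant) tuples, gr.Chatbot's default
# format. The chain also keeps its own ConversationBufferMemory, which
# supplies "chat_history" internally, so the value passed in the invoke()
# call above is largely redundant but harmless.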
def get_file_icon(file_path):
    """Return an appropriate emoji icon based on file extension"""
    ext = Path(file_path).suffix.lower()
    icons = {
        '.pdf': '📄',
        '.txt': '📝',
        '.md': '📋',
        '.py': '🐍',
        '.js': '⚙️',
        '.json': '📊',
        '.jsonl': '📊',
        '.csv': '📈',
        '.html': '🌐',
        '.css': '🎨',
    }
    return icons.get(ext, '📁')
def display_file_list(file_obj):
    if not file_obj:
        return "No files uploaded yet"
    file_list = [f"{get_file_icon(x.name)} {Path(x.name).name}" for x in file_obj if x is not None]
    return "\n".join(file_list)
def demo():
    with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue", neutral_hue="sky")) as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>📚 Enhanced RAG Chatbot</h1></center>")
        gr.Markdown("""<b>Query your documents!</b> This enhanced AI agent performs retrieval augmented generation (RAG) on various document types,
        including PDFs, text files, markdown, code files, and structured data (CSV, JSON, JSONL). <b>Please do not upload confidential documents.</b>
        """)
        with gr.Row():
            with gr.Column(scale=86):
                gr.Markdown("<b>Step 1 - Upload Documents and Initialize RAG Pipeline</b>")
                with gr.Row():
                    with gr.Column(scale=7):
                        document = gr.Files(
                            height=300,
                            file_count="multiple",
                            file_types=[".pdf", ".txt", ".md", ".py", ".js", ".json", ".jsonl", ".csv", ".html", ".css"],
                            interactive=True,
                            label="Upload Documents"
                        )
                    with gr.Column(scale=3):
                        file_list = gr.Textbox(
                            label="Uploaded Files",
                            value="No files uploaded yet",
                            interactive=False,
                            lines=12
                        )
                document.upload(
                    display_file_list,
                    inputs=[document],
                    outputs=[file_list]
                )
                with gr.Row():
                    db_btn = gr.Button("Create Vector Database", variant="primary")
                with gr.Row():
                    db_progress = gr.Textbox(
                        value="Not initialized",
                        show_label=False,
                        container=True
                    )
                gr.Markdown("<b>Step 2 - Select LLM and Parameters</b>")
                with gr.Row():
                    llm_btn = gr.Radio(
                        list_llm_simple,
                        label="Available LLMs",
                        value=list_llm_simple[0],
                        type="index"
                    )
                with gr.Row():
                    with gr.Accordion("LLM Parameters", open=False):
                        with gr.Row():
                            slider_temperature = gr.Slider(
                                minimum=0.01,
                                maximum=1.0,
                                value=0.5,
                                step=0.1,
                                label="Temperature",
                                info="Controls randomness in generation",
                                interactive=True
                            )
                        with gr.Row():
                            slider_maxtokens = gr.Slider(
                                minimum=128,
                                maximum=8192,
                                value=4096,
                                step=128,
                                label="Max New Tokens",
                                info="Maximum tokens to generate",
                                interactive=True
                            )
                        with gr.Row():
                            slider_topk = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=3,
                                step=1,
                                label="Top-k",
                                info="Number of highest-probability tokens to sample from",
                                interactive=True
                            )
                with gr.Row():
                    qachain_btn = gr.Button("Initialize Chatbot", variant="primary")
                with gr.Row():
                    llm_progress = gr.Textbox(
                        value="Not initialized",
                        show_label=False,
                        container=True
                    )
            with gr.Column(scale=200):
                gr.Markdown("<b>Step 3 - Chat with Your Documents</b>")
                chatbot = gr.Chatbot(height=505)
                with gr.Accordion("Relevant Context from Documents", open=False):
                    with gr.Row():
                        doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                        source1_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                        source2_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                        source3_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask a question about your documents...",
                        container=True,
                        lines=2
                    )
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
        # Preprocessing events
        db_btn.click(
            initialize_database,
            inputs=[document],
            outputs=[vector_db, db_progress]
        )
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress]
        ).then(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
        # Chatbot events
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()
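# To run outside of a Space (an assumption - this file is laid out as a
# Space's app.py): set HF_TOKEN to a token with Inference API access and
# start the app with `python app.py`.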