|
|
|
import gradio as gr |
|
import shutil |
|
import tempfile |
|
from pathlib import Path |
|
import time |
|
|
|
|
|
from Chroma import create_db |
|
from LangChain import query, load_chain |
|
|
|
|
|
|
|
def load_data(
    chunk_size,
    chunk_overlap,
    uploaded_files,
    existing_data,
    progress=gr.Progress(),
):
    """Build the vector database and QA chain from the uploaded files.

    Parameters
    ----------
    chunk_size, chunk_overlap : int
        Text-splitting parameters forwarded verbatim to ``create_db``.
    uploaded_files : list | None
        Gradio file objects; each object's ``.name`` is a path on disk.
    existing_data : dict | None
        State dict from a previous load; its ``temp_dir`` is removed so
        repeated loads do not accumulate temp directories.
    progress : gr.Progress
        Progress tracker injected by Gradio (mutable default is the
        documented Gradio pattern, not a bug).

    Returns
    -------
    tuple[dict, str]
        (state dict for ``gr.State``, status message for the UI). On
        failure the dict carries ``loaded=False`` and an ``error`` key.
    """
    temp_dir = None
    try:
        # Guard: avoid ZeroDivisionError below and give a clear message
        # when the user clicks "Load Data" without selecting a folder.
        if not uploaded_files:
            return (
                {"loaded": False, "error": "No files uploaded"},
                "❌ Error: No files uploaded",
            )

        progress(0, desc="Loading chain...")
        time.sleep(0.5)  # brief pause so each progress step is visible
        print("Loading chain...")

        chain = load_chain()
        progress(0.3, desc="Chain loaded")
        time.sleep(0.5)
        print("Chain loaded")

        print("Creating db...")

        # Drop the previous working directory so reloads don't leak space.
        if existing_data and "temp_dir" in existing_data:
            shutil.rmtree(existing_data["temp_dir"], ignore_errors=True)

        temp_dir = tempfile.mkdtemp()
        print(f"Copying files to {temp_dir}...")

        total = len(uploaded_files)
        for i, uploaded_file in enumerate(uploaded_files, 1):
            src_path = Path(uploaded_file.name)
            shutil.move(src_path, temp_dir)
            # Display only the base name, not the full temp path.
            progress(0.3 + 0.2 * i / total, f"Processing {src_path.name}")
            time.sleep(0.1)

        progress(0.5, desc="Creating db...")
        db = create_db(chunk_size, chunk_overlap, INPUT_PATH=temp_dir, CHROMA_PATH=temp_dir)
        progress(1.0, desc="DB created")
        print("DB created")

        return {
            "db": db,
            "chain": chain,
            "temp_dir": temp_dir,
            "loaded": True,
            "file_count": total,
        }, "✅ Data loaded successfully!"

    except Exception as e:
        # Don't leak the freshly-created temp dir when a later step fails.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return {"loaded": False, "error": str(e)}, f"❌ Error: {str(e)}"
|
|
|
|
|
def chat_response(message, chat_history, data):
    """Answer *message* against the loaded db/chain and append to history.

    Parameters
    ----------
    message : str
        The user's question.
    chat_history : list[tuple[str, str]]
        Existing (question, answer) pairs; mutated in place.
    data : dict | None
        State dict produced by ``load_data``; ``None`` before any load
        (the initial ``gr.State()`` value).

    Returns
    -------
    list[tuple[str, str]]
        The updated chat history.
    """
    if not data or not data.get("loaded"):
        # Bug fix: `data` is None before the first load, so the original
        # `data.get(...)` raised AttributeError here. Use an empty-dict
        # fallback for the error-message lookup.
        error_msg = (data or {}).get("error", "Please load data first!")
        chat_history.append((message, error_msg))
        return chat_history

    answer, sources = query(message, data["db"], data["chain"])
    # Reduce each source path to its base name for display.
    sources = "\n".join(s_file.split("/")[-1] for s_file in sources.split("\n"))
    response = f"{answer}\n\nSources:\n{sources}"

    chat_history.append((message, response))
    return chat_history
|
|
|
|
|
# ---- UI layout & event wiring -------------------------------------------
# Two-column layout: the left column collects chunking parameters and the
# document folder; the right column (hidden until data loads) hosts the chat.
with gr.Blocks(title="Document Analysis Chatbot") as demo:

    # Session-scoped store for the dict returned by load_data
    # (db / chain / temp_dir / loaded / file_count). Starts as None.
    data_store = gr.State()

    with gr.Row():

        with gr.Column(scale=1):
            gr.Markdown("## Data Upload")

            # Text-splitting parameters forwarded verbatim to create_db.
            chunk_size = gr.Number(label="Chunk Size", value=1000)
            chunk_overlap = gr.Number(label="Chunk Overlap", value=500)

            # Directory upload: produces the list of file objects that
            # load_data copies into its temp working directory.
            folder_input = gr.File(file_count="directory", label="Upload Folder")

            # Read-only status line fed by load_data's second return value.
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                show_label=False
            )

            load_btn = gr.Button("Load Data", variant="primary")

        # Chat column starts hidden and is revealed after a load attempt.
        # NOTE(review): the .then() below shows it unconditionally, even
        # when load_data reported an error — confirm this is intended.
        with gr.Column(scale=3, visible=False) as chat_col:
            gr.Markdown("## Chat Interface")
            chatbot = gr.Chatbot(
                label="Document Analysis Chat",
                type="tuples",  # history as (user, bot) tuples, matching chat_response
                bubble_full_width=False,
                render_markdown=True,
                height=500,
            )
            msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
            clear_btn = gr.Button("Clear Chat", variant="secondary")

    # Load button: build db + chain, then reveal the chat column.
    load_btn.click(
        fn=load_data,
        inputs=[chunk_size, chunk_overlap, folder_input, data_store],
        outputs=[data_store, status_text],
    ).then(fn=lambda: gr.Column(visible=True), outputs=chat_col)

    # Enter in the question box: answer, then clear the input field.
    msg.submit(
        fn=chat_response,
        inputs=[msg, chatbot, data_store],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)

    # Reset the displayed conversation history.
    clear_btn.click(lambda: [], None, chatbot)

if __name__ == "__main__":
    demo.launch()
|
|