gradio app
- Chroma.py +59 -0
- query.py → LangChain.py +34 -13
- app.py +137 -0
- create_db.py +0 -42
- requirements.txt +2 -0
- run.sh +2 -1
Chroma.py
ADDED
@@ -0,0 +1,59 @@
+from langchain_community.document_loaders import DirectoryLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_chroma import Chroma
+import os
+import shutil
+
+
+def create_db(
+    chunk_size,
+    chunk_overlap,
+    INPUT_PATH="./data/books/",
+    INPUT_GLOB=["*.txt", "*.md"],
+    MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
+    CHROMA_PATH="./chromadb/",
+):
+    # setup embeddings
+    embeddings = HuggingFaceEmbeddings(
+        model_name=MODEL_NAME,
+        model_kwargs={"device": "cuda", "trust_remote_code": True},
+        encode_kwargs={"normalize_embeddings": True},
+    )
+
+    # load documents
+    raw_documents = DirectoryLoader(INPUT_PATH, glob=INPUT_GLOB).load()
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+        length_function=len,
+        add_start_index=True,
+    )
+    documents = text_splitter.split_documents(raw_documents)
+    print(f"Split {len(raw_documents)} documents into {len(documents)} chunks.")
+
+    # Clear out the database first.
+    if os.path.exists(CHROMA_PATH):
+        shutil.rmtree(CHROMA_PATH)
+
+    # Create a new DB from the documents.
+    db = Chroma.from_documents(
+        documents,
+        embeddings,
+        persist_directory=CHROMA_PATH,
+        collection_metadata={"hnsw:space": "cosine"},
+    )
+    print(f"Saved {len(documents)} chunks to {CHROMA_PATH}.")
+
+    return db
+
+
+if __name__ == "__main__":
+    create_db(
+        1000,
+        500,
+        INPUT_PATH="./data/books/dracula_segmented/",
+        INPUT_GLOB=["*.txt"],
+        MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
+        CHROMA_PATH="./chromadb/",
+    )
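
For reference, a minimal sketch of using the new module from another script (an assumption for illustration, not part of the commit); it presumes ./data/books/ exists and a CUDA device is available, since create_db hard-codes device="cuda" for the embeddings:

    # Build the vector store, then run a plain similarity search against it.
    from Chroma import create_db

    db = create_db(chunk_size=1000, chunk_overlap=500, INPUT_PATH="./data/books/")

    # create_db returns a langchain_chroma.Chroma instance, so the standard
    # VectorStore retrieval API applies:
    results = db.similarity_search_with_relevance_scores("Van Helsing", k=3)
    for doc, score in results:
        print(score, doc.metadata.get("source"))
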
query.py → LangChain.py
RENAMED
@@ -6,13 +6,8 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_chroma import Chroma
 
-CHROMA_PATH = "chromadb/"
 
-
-MODEL_NAME = "Alibaba-NLP/gte-multilingual-base"
-
-
-def load_db():
+def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
     # setup embeddings
     embeddings = HuggingFaceEmbeddings(
         model_name=MODEL_NAME,
@@ -41,13 +36,7 @@ def query_db(db, query_text):
     return context_text, sources
 
 
-
-db = load_db()
-
-question = "Cor do cabelo de Van Helsing"
-
-context, sources = query_db(db, question)
-
+def load_chain():
     # prompt chat
     prompt = ChatPromptTemplate(
         [
@@ -74,6 +63,38 @@ Answer the question based on the above context in question's original language:
     # pipeline
     chain = prompt | llm
 
+    return chain
+
+
+def query(question, db, chain):
+    context, sources = query_db(db, question)
+
+    print(f"Context:\n{context}\n*************************")
+
+    # ask
+    answer = chain.invoke(
+        {
+            "context": context,
+            "question": question,
+        }
+    ).content
+    print(f"Answer:\n{answer}\n*************************")
+
+    print(f"Sources:\n{sources}")
+
+    return answer, sources
+
+
+if __name__ == "__main__":
+    db = load_db()
+
+    question = "Cor do cabelo de Van Helsing"
+
+    context, sources = query_db(db, question)
+
+    # model creation
+    chain = load_chain()
+
     print(f"Context:\n{context}\n*************************")
 
     # ask
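
With the rename, the script's one-shot logic is factored into reusable pieces: load_db, load_chain, and query. A minimal sketch of calling them from another script (illustrative only; the English question is a stand-in, and the Chroma store is assumed to already exist at chromadb/, built by Chroma.py):

    from LangChain import load_db, load_chain, query

    db = load_db(CHROMA_PATH="chromadb/")
    chain = load_chain()

    # query() retrieves context from the store, invokes the chain, and
    # returns the generated answer plus the source chunks it used.
    answer, sources = query("What color is Van Helsing's hair?", db, chain)
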
app.py
ADDED
@@ -0,0 +1,137 @@
+# app interface related
+import gradio as gr
+import shutil
+import tempfile
+from pathlib import Path
+import time
+
+# ai related
+from Chroma import create_db
+from LangChain import query, load_chain
+
+
+# build the db and chain, and return them as the app's stored state
+def load_data(
+    chunk_size,
+    chunk_overlap,
+    uploaded_files,
+    existing_data,
+    progress=gr.Progress(),
+):
+    try:
+        progress(0, desc="Loading chain...")
+        time.sleep(0.5)
+        print("Loading chain...")
+        # chain load
+        chain = load_chain()
+        progress(0.3, desc="Chain loaded")
+        time.sleep(0.5)
+        print("Chain loaded")
+
+        print("Creating db...")
+        # clean up previous temporary directory if it exists
+        if existing_data and "temp_dir" in existing_data:
+            shutil.rmtree(existing_data["temp_dir"])
+
+        # create new consolidated temporary directory
+        temp_dir = tempfile.mkdtemp()
+
+        print(f"Copying files to {temp_dir}...")
+        # consolidate the uploaded files into the temporary directory
+        for i, uploaded_file in enumerate(uploaded_files, 1):
+            src_path = Path(uploaded_file.name)
+            # move file to consolidated directory
+            shutil.move(src_path, temp_dir)
+            # update progress bar
+            progress(
+                0.3 + 0.2 * i / len(uploaded_files), f"Processing {uploaded_file.name.split('/')[-1]}"
+            )
+            time.sleep(0.1)
+
+        # create db file
+        progress(0.5, desc="Creating db...")
+        db = create_db(chunk_size, chunk_overlap, INPUT_PATH=temp_dir, CHROMA_PATH=temp_dir)
+        progress(1.0, desc="DB created")
+        print("DB created")
+
+        return {
+            "db": db,
+            "chain": chain,
+            "temp_dir": temp_dir,
+            "loaded": True,
+            "file_count": len(uploaded_files),
+        }, "✅ Data loaded successfully!"
+    except Exception as e:
+        return {"loaded": False, "error": str(e)}, f"❌ Error: {str(e)}"
+
+
+def chat_response(message, chat_history, data):
+    if not data or not data.get("loaded"):
+        error_msg = (data or {}).get("error", "Please load data first!")  # guard: data may be None
+        chat_history.append((message, error_msg))
+        return chat_history
+
+    # responses based on the input data
+    answer, sources = query(message, data["db"], data["chain"])
+    sources = "\n".join([s_file.split("/")[-1] for s_file in sources.split("\n")])
+    response = f"{answer}\n\nSources:\n{sources}"
+
+    # Append messages as tuples (user, assistant) instead of dictionaries
+    chat_history.append((message, response))
+    return chat_history
+
+
+with gr.Blocks(title="Document Analysis Chatbot") as demo:
+    # store loaded data
+    data_store = gr.State()
+
+    with gr.Row():
+        # Left Column - Inputs
+        with gr.Column(scale=1):
+            gr.Markdown("## Data Upload")
+            # create db parameters
+            chunk_size = gr.Number(label="Chunk Size", value=1000)
+            chunk_overlap = gr.Number(label="Chunk Overlap", value=500)
+            # load file
+            folder_input = gr.File(file_count="directory", label="Upload Folder")
+            # Add status display
+            status_text = gr.Textbox(
+                label="Status",
+                interactive=False,
+                show_label=False
+            )
+            # load button
+            load_btn = gr.Button("Load Data", variant="primary")
+
+        # Right Column - Chat
+        with gr.Column(scale=3, visible=False) as chat_col:
+            gr.Markdown("## Chat Interface")
+            chatbot = gr.Chatbot(
+                label="Document Analysis Chat",
+                type="tuples",
+                bubble_full_width=False,  # Prevent stretching of messages
+                render_markdown=True,  # Handle markdown formatting properly
+                height=500,
+            )
+            msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
+            clear_btn = gr.Button("Clear Chat", variant="secondary")
+
+    # Loading indicators - update to handle multiple outputs
+    load_btn.click(
+        fn=load_data,
+        inputs=[chunk_size, chunk_overlap, folder_input, data_store],
+        outputs=[data_store, status_text],
+    ).then(fn=lambda: gr.Column(visible=True), outputs=chat_col)
+
+    # Chat interaction
+    msg.submit(
+        fn=chat_response,
+        inputs=[msg, chatbot, data_store],
+        outputs=[chatbot],
+    ).then(lambda: "", None, msg)
+
+    # Clear chat
+    clear_btn.click(lambda: [], None, chatbot)
+
+if __name__ == "__main__":
+    demo.launch()
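
A minimal sketch of serving the app (the queue() call and the share flag are assumptions for illustration, not part of the commit):

    # Run the Gradio UI defined in app.py.
    from app import demo

    demo.queue()             # serialize requests, since queries hit a GPU-bound model
    demo.launch(share=True)  # share=True optionally exposes a temporary public URL
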
create_db.py
DELETED
@@ -1,42 +0,0 @@
-from langchain_community.document_loaders import DirectoryLoader
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_chroma import Chroma
-import os
-import shutil
-
-CHROMA_PATH = "chromadb/"
-
-INPUT_PATH = "./data/books/dracula_segmented/"
-INPUT_GLOB = "*.txt"
-
-# free models
-MODEL_NAME = "Alibaba-NLP/gte-multilingual-base"
-
-# setup embeddings
-embeddings = HuggingFaceEmbeddings(
-    model_name=MODEL_NAME,
-    model_kwargs={"device": "cuda", "trust_remote_code": True},
-    encode_kwargs={"normalize_embeddings": True},
-)
-
-# load documents
-raw_documents = DirectoryLoader(INPUT_PATH, glob=INPUT_GLOB).load()
-text_splitter = RecursiveCharacterTextSplitter(
-    chunk_size=1000, chunk_overlap=500, length_function=len, add_start_index=True
-)
-documents = text_splitter.split_documents(raw_documents)
-print(f"Split {len(raw_documents)} documents into {len(documents)} chunks.")
-
-# Clear out the database first.
-if os.path.exists(CHROMA_PATH):
-    shutil.rmtree(CHROMA_PATH)
-
-# Create a new DB from the documents.
-db = Chroma.from_documents(
-    documents,
-    embeddings,
-    persist_directory=CHROMA_PATH,
-    collection_metadata={"hnsw:space": "cosine"},
-)
-print(f"Saved {len(documents)} chunks to {CHROMA_PATH}.")
requirements.txt
CHANGED
@@ -4,6 +4,8 @@ langchain-huggingface==0.1.2
 langchain-chroma==0.2.2 # vectorized documents for query
 sentence-transformers==3.4.1 # free embeddings
 unstructured==0.16.23
+unstructured[md]==0.16.23
+gradio==5.19.0
 # QWEN libs
 torch==2.4.1
 triton==3.0.0
run.sh
CHANGED
@@ -13,6 +13,7 @@ fi
 eval "$($CONDA_PATH shell.bash hook)"
 
 conda activate specialist
-python
+# python Chroma.py
+python LangChain.py
 conda deactivate
 echo "Completed."