merligus committed
Commit d3a1fe2 · 1 Parent(s): 2566a62

gradio app

Files changed (6):
  1. Chroma.py +59 -0
  2. query.py → LangChain.py +34 -13
  3. app.py +137 -0
  4. create_db.py +0 -42
  5. requirements.txt +2 -0
  6. run.sh +2 -1
Chroma.py ADDED
@@ -0,0 +1,59 @@
+ from langchain_community.document_loaders import DirectoryLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_chroma import Chroma
+ import os
+ import shutil
+
+
+ def create_db(
+     chunk_size,
+     chunk_overlap,
+     INPUT_PATH="./data/books/",
+     INPUT_GLOB=["*.txt", "*.md"],
+     MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
+     CHROMA_PATH="./chromadb/",
+ ):
+     # setup embeddings
+     embeddings = HuggingFaceEmbeddings(
+         model_name=MODEL_NAME,
+         model_kwargs={"device": "cuda", "trust_remote_code": True},
+         encode_kwargs={"normalize_embeddings": True},
+     )
+
+     # load documents
+     raw_documents = DirectoryLoader(INPUT_PATH, glob=INPUT_GLOB).load()
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap,
+         length_function=len,
+         add_start_index=True,
+     )
+     documents = text_splitter.split_documents(raw_documents)
+     print(f"Split {len(raw_documents)} documents into {len(documents)} chunks.")
+
+     # Clear out the database first.
+     if os.path.exists(CHROMA_PATH):
+         shutil.rmtree(CHROMA_PATH)
+
+     # Create a new DB from the documents.
+     db = Chroma.from_documents(
+         documents,
+         embeddings,
+         persist_directory=CHROMA_PATH,
+         collection_metadata={"hnsw:space": "cosine"},
+     )
+     print(f"Saved {len(documents)} chunks to {CHROMA_PATH}.")
+
+     return db
+
+
+ if __name__ == "__main__":
+     create_db(
+         1000,
+         500,
+         INPUT_PATH="./data/books/dracula_segmented/",
+         INPUT_GLOB=["*.txt"],
+         MODEL_NAME="Alibaba-NLP/gte-multilingual-base",
+         CHROMA_PATH="./chromadb/",
+     )
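For reference, the new create_db is now importable from other modules (app.py below uses it this way). A minimal usage sketch, assuming the default ./data/books/ folder exists and a CUDA device is available, since the embeddings are pinned to device="cuda":

    from Chroma import create_db

    # index ./data/books/ into 1000-character chunks overlapping by 500,
    # persisting the Chroma collection to ./chromadb/
    db = create_db(chunk_size=1000, chunk_overlap=500)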
query.py → LangChain.py RENAMED
@@ -6,13 +6,8 @@ from langchain_core.prompts import ChatPromptTemplate
  from langchain_huggingface import HuggingFaceEmbeddings
  from langchain_chroma import Chroma
 
- CHROMA_PATH = "chromadb/"
-
- # free model
- MODEL_NAME = "Alibaba-NLP/gte-multilingual-base"
-
 
- def load_db():
+ def load_db(CHROMA_PATH="chromadb/", MODEL_NAME="Alibaba-NLP/gte-multilingual-base"):
      # setup embeddings
      embeddings = HuggingFaceEmbeddings(
          model_name=MODEL_NAME,
@@ -41,13 +36,7 @@ def query_db(db, query_text):
      return context_text, sources
 
 
- if __name__ == "__main__":
-     db = load_db()
-
-     question = "Cor do cabelo de Van Helsing"
-
-     context, sources = query_db(db, question)
-
+ def load_chain():
      # prompt chat
      prompt = ChatPromptTemplate(
          [
@@ -74,6 +63,38 @@ Answer the question based on the above context in question's original language:
      # pipeline
      chain = prompt | llm
 
+     return chain
+
+
+ def query(question, db, chain):
+     context, sources = query_db(db, question)
+
+     print(f"Context:\n{context}\n*************************")
+
+     # ask
+     answer = chain.invoke(
+         {
+             "context": context,
+             "question": question,
+         }
+     ).content
+     print(f"Answer:\n{answer}\n*************************")
+
+     print(f"Sources:\n{sources}")
+
+     return answer, sources
+
+
+ if __name__ == "__main__":
+     db = load_db()
+
+     question = "Cor do cabelo de Van Helsing"
+
+     context, sources = query_db(db, question)
+
+     # model creation
+     chain = load_chain()
+
      print(f"Context:\n{context}\n*************************")
 
      # ask
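The rename splits the old query.py script into reusable pieces: load_db opens the persisted Chroma collection, load_chain builds the prompt-to-LLM pipeline once, and query ties them together per question. A minimal sketch of the resulting flow, assuming the chromadb/ index has already been built by Chroma.py:

    from LangChain import load_db, load_chain, query

    db = load_db()        # embeddings + persisted Chroma collection
    chain = load_chain()  # prompt template piped into the LLM
    # retrieves context from the db, invokes the chain, prints and returns both
    answer, sources = query("Cor do cabelo de Van Helsing", db, chain)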
app.py ADDED
@@ -0,0 +1,137 @@
+ # app interface related
+ import gradio as gr
+ import shutil
+ import tempfile
+ from pathlib import Path
+ import time
+
+ # ai related
+ from Chroma import create_db
+ from LangChain import query, load_chain
+
+
+ # function to store the state
+ def load_data(
+     chunk_size,
+     chunk_overlap,
+     uploaded_files,
+     existing_data,
+     progress=gr.Progress(),
+ ):
+     try:
+         progress(0, desc="Loading chain...")
+         time.sleep(0.5)
+         print("Loading chain...")
+         # chain load
+         chain = load_chain()
+         progress(0.3, desc="Chain loaded")
+         time.sleep(0.5)
+         print("Chain loaded")
+
+         print("Creating db...")
+         # clean up previous temporary directory if it exists
+         if existing_data and "temp_dir" in existing_data:
+             shutil.rmtree(existing_data["temp_dir"])
+
+         # create new consolidated temporary directory
+         temp_dir = tempfile.mkdtemp()
+
+         print(f"Copying files to {temp_dir}...")
+         # preserve original directory structure
+         for i, uploaded_file in enumerate(uploaded_files, 1):
+             src_path = Path(uploaded_file.name)
+             # move file to consolidated directory
+             shutil.move(src_path, temp_dir)
+             # update progress bar
+             progress(
+                 0.3 + 0.2 * i / len(uploaded_files), f"Processing {uploaded_file.name.split('/')[-1]}"
+             )
+             time.sleep(0.1)
+
+         # create db file
+         progress(0.5, desc="Creating db...")
+         db = create_db(chunk_size, chunk_overlap, INPUT_PATH=temp_dir, CHROMA_PATH=temp_dir)
+         progress(1.0, desc="DB created")
+         print("DB created")
+
+         return {
+             "db": db,
+             "chain": chain,
+             "temp_dir": temp_dir,
+             "loaded": True,
+             "file_count": len(uploaded_files),
+         }, "✅ Data loaded successfully!"
+     except Exception as e:
+         return {"loaded": False, "error": str(e)}, f"❌ Error: {str(e)}"
+
+
+ def chat_response(message, chat_history, data):
+     if not data or not data.get("loaded"):
+         error_msg = data.get("error", "Please load data first!")
+         chat_history.append((message, error_msg))
+         return chat_history
+
+     # responses based on the input data
+     answer, sources = query(message, data["db"], data["chain"])
+     sources = "\n".join([s_file.split("/")[-1] for s_file in sources.split("\n")])
+     response = f"{answer}\n\nSources:\n{sources}"
+
+     # Append messages as tuples (user, assistant) instead of dictionaries
+     chat_history.append((message, response))
+     return chat_history
+
+
+ with gr.Blocks(title="Document Analysis Chatbot") as demo:
+     # store loaded data
+     data_store = gr.State()
+
+     with gr.Row():
+         # Left Column - Inputs
+         with gr.Column(scale=1):
+             gr.Markdown("## Data Upload")
+             # create db parameters
+             chunk_size = gr.Number(label="Chunk Size", value=1000)
+             chunk_overlap = gr.Number(label="Chunk Overlap", value=500)
+             # load file
+             folder_input = gr.File(file_count="directory", label="Upload Folder")
+             # Add status display
+             status_text = gr.Textbox(
+                 label="Status",
+                 interactive=False,
+                 show_label=False
+             )
+             # load button
+             load_btn = gr.Button("Load Data", variant="primary")
+
+         # Right Column - Chat
+         with gr.Column(scale=3, visible=False) as chat_col:
+             gr.Markdown("## Chat Interface")
+             chatbot = gr.Chatbot(
+                 label="Document Analysis Chat",
+                 type="tuples",
+                 bubble_full_width=False,  # Prevent stretching of messages
+                 render_markdown=True,  # Handle markdown formatting properly
+                 height=500,
+             )
+             msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
+             clear_btn = gr.Button("Clear Chat", variant="secondary")
+
+     # Loading indicators - update to handle multiple outputs
+     load_btn.click(
+         fn=load_data,
+         inputs=[chunk_size, chunk_overlap, folder_input, data_store],
+         outputs=[data_store, status_text],
+     ).then(fn=lambda: gr.Column(visible=True), outputs=chat_col)
+
+     # Chat interaction
+     msg.submit(
+         fn=chat_response,
+         inputs=[msg, chatbot, data_store],
+         outputs=[chatbot],
+     ).then(lambda: "", None, msg)
+
+     # Clear chat
+     clear_btn.click(lambda: [], None, chatbot)
+
+ if __name__ == "__main__":
+     demo.launch()
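The chat panel reveal relies on Gradio event chaining: load_btn.click() first runs load_data and writes to data_store and status_text, then the chained .then() returns gr.Column(visible=True) to unhide chat_col. A stripped-down sketch of just that pattern (hypothetical names, not the app's full wiring):

    import gradio as gr

    with gr.Blocks() as demo:
        status = gr.Textbox(show_label=False)
        with gr.Column(visible=False) as panel:  # hidden until loading finishes
            gr.Markdown("Chat panel")
        btn = gr.Button("Load")
        # step 1 writes the status; the chained .then() step reveals the panel
        btn.click(fn=lambda: "loaded", outputs=status).then(
            fn=lambda: gr.Column(visible=True), outputs=panel
        )

    demo.launch()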
create_db.py DELETED
@@ -1,42 +0,0 @@
- from langchain_community.document_loaders import DirectoryLoader
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_huggingface import HuggingFaceEmbeddings
- from langchain_chroma import Chroma
- import os
- import shutil
-
- CHROMA_PATH = "chromadb/"
-
- INPUT_PATH = "./data/books/dracula_segmented/"
- INPUT_GLOB = "*.txt"
-
- # free models
- MODEL_NAME = "Alibaba-NLP/gte-multilingual-base"
-
- # setup embeddings
- embeddings = HuggingFaceEmbeddings(
-     model_name=MODEL_NAME,
-     model_kwargs={"device": "cuda", "trust_remote_code": True},
-     encode_kwargs={"normalize_embeddings": True},
- )
-
- # load documents
- raw_documents = DirectoryLoader(INPUT_PATH, glob=INPUT_GLOB).load()
- text_splitter = RecursiveCharacterTextSplitter(
-     chunk_size=1000, chunk_overlap=500, length_function=len, add_start_index=True
- )
- documents = text_splitter.split_documents(raw_documents)
- print(f"Split {len(raw_documents)} documents into {len(documents)} chunks.")
-
- # Clear out the database first.
- if os.path.exists(CHROMA_PATH):
-     shutil.rmtree(CHROMA_PATH)
-
- # Create a new DB from the documents.
- db = Chroma.from_documents(
-     documents,
-     embeddings,
-     persist_directory=CHROMA_PATH,
-     collection_metadata={"hnsw:space": "cosine"},
- )
- print(f"Saved {len(documents)} chunks to {CHROMA_PATH}.")
requirements.txt CHANGED
@@ -4,6 +4,8 @@ langchain-huggingface==0.1.2
  langchain-chroma==0.2.2 # vectorized documents for query
  sentence-transformers==3.4.1 # free embeddings
  unstructured==0.16.23
+ unstructured[md]==0.16.23
+ gradio==5.19.0
  # QWEN libs
  torch==2.4.1
  triton==3.0.0
run.sh CHANGED
@@ -13,6 +13,7 @@ fi
  eval "$($CONDA_PATH shell.bash hook)"
 
  conda activate specialist
- python query.py
+ # python Chroma.py
+ python LangChain.py
  conda deactivate
  echo "Completed."