jzou19950715 committed
Commit 1773e23 · verified · 1 Parent(s): 18f7800

Update app.py

Files changed (1): app.py +453 -50

app.py CHANGED
@@ -1,64 +1,467 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
  if __name__ == "__main__":
-     demo.launch()

+ import os
+ import sys
+ import logging
+ from pathlib import Path
+ import json
+ import hashlib
+ from datetime import datetime
+ import threading
+ import queue
+ from typing import List, Dict, Any, Tuple, Optional
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # Importing necessary libraries
+ import torch
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+ import chromadb
+ from chromadb.utils import embedding_functions
  import gradio as gr
+ from openai import OpenAI
+ import google.generativeai as genai
+
+ # Configuration class
+ class Config:
+     """Configuration for vector store and RAG"""
+     def __init__(self,
+                  local_dir: str = "./chroma_data",
+                  batch_size: int = 20,
+                  max_workers: int = 4,
+                  embedding_model: str = "all-MiniLM-L6-v2",
+                  collection_name: str = "markdown_docs"):
+         self.local_dir = local_dir
+         self.batch_size = batch_size
+         self.max_workers = max_workers
+         self.checkpoint_file = Path(local_dir) / "checkpoint.json"
+         self.embedding_model = embedding_model
+         self.collection_name = collection_name
+
+         # Create local directory for checkpoints and Chroma
+         Path(local_dir).mkdir(parents=True, exist_ok=True)
+
+ # Embedding engine
+ class EmbeddingEngine:
+     """Handle embeddings with a lightweight model"""
+
+     def __init__(self, model_name="all-MiniLM-L6-v2"):
+         # Use GPU if available
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {self.device}")
+
+         # Try multiple model options in order of preference
+         model_options = [
+             model_name,
+             "all-MiniLM-L6-v2",
+             "paraphrase-MiniLM-L3-v2",
+             "all-mpnet-base-v2"  # Higher quality but larger model
+         ]
+
+         self.model = None
+
+         # Try each model in order until one works
+         for model_option in model_options:
+             try:
+                 logger.info(f"Attempting to load model: {model_option}")
+                 self.model = SentenceTransformer(model_option)
+
+                 # Move model to device
+                 self.model.to(self.device)
+
+                 logger.info(f"Successfully loaded model: {model_option}")
+                 self.model_name = model_option
+                 self.vector_size = self.model.get_sentence_embedding_dimension()
+                 break
+
+             except Exception as e:
+                 logger.warning(f"Failed to load model {model_option}: {str(e)}")
+
+         if self.model is None:
+             logger.error("Failed to load any embedding model. Exiting.")
+             sys.exit(1)
+
+     def encode(self, text, batch_size=32):
+         """Get embedding for a text or list of texts"""
+         # Handle single text
+         if isinstance(text, str):
+             texts = [text]
+         else:
+             texts = text
+
+         # Truncate texts if necessary to avoid tokenization issues
+         truncated_texts = [t[:50000] if len(t) > 50000 else t for t in texts]
+
+         # Generate embeddings
+         try:
+             embeddings = self.model.encode(truncated_texts, batch_size=batch_size,
+                                            show_progress_bar=False, convert_to_numpy=True)
+             return embeddings
+         except Exception as e:
+             logger.error(f"Error generating embeddings: {e}")
+             # Return zero embeddings as fallback
+             return np.zeros((len(truncated_texts), self.vector_size))
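A minimal usage sketch of EmbeddingEngine (illustrative only, not part of app.py; `engine` and the sample strings are hypothetical): encode() accepts a single string or a list of strings and returns a NumPy array of shape (n_texts, vector_size).

    engine = EmbeddingEngine("all-MiniLM-L6-v2")      # hypothetical instance
    vectors = engine.encode(["first note", "second note"])
    print(vectors.shape)                              # (2, 384) for all-MiniLM-L6-v2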
+
+ class VectorStoreManager:
+     """Manage Chroma vector store operations - upload, query, etc."""
+
+     def __init__(self, config: Config):
+         self.config = config
+
+         # Initialize Chroma client (local persistence)
+         logger.info(f"Initializing Chroma at {config.local_dir}")
+         self.client = chromadb.PersistentClient(path=config.local_dir)
+
+         # Get or create collection
+         try:
+             # Initialize embedding model
+             logger.info("Loading embedding model...")
+             self.embedding_engine = EmbeddingEngine(config.embedding_model)
+             logger.info(f"Using model: {self.embedding_engine.model_name}")
+
+             # Create embedding function
+             sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+                 model_name=self.embedding_engine.model_name
+             )
+
+             # Try to get existing collection
+             try:
+                 self.collection = self.client.get_collection(
+                     name=config.collection_name,
+                     embedding_function=sentence_transformer_ef
+                 )
+                 logger.info(f"Using existing collection: {config.collection_name}")
+             except:
+                 # Create new collection if it doesn't exist
+                 self.collection = self.client.create_collection(
+                     name=config.collection_name,
+                     embedding_function=sentence_transformer_ef,
+                     metadata={"hnsw:space": "cosine"}
+                 )
+                 logger.info(f"Created new collection: {config.collection_name}")
+
+         except Exception as e:
+             logger.error(f"Error initializing Chroma collection: {e}")
+             sys.exit(1)
+
+     def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
+         """
+         Query the vector store with a text query
+         """
+         try:
+             # Query the collection
+             search_results = self.collection.query(
+                 query_texts=[query_text],
+                 n_results=n_results,
+                 include=["documents", "metadatas", "distances"]
+             )
+
+             # Format results
+             results = []
+             if search_results["documents"] and len(search_results["documents"][0]) > 0:
+                 for i in range(len(search_results["documents"][0])):
+                     results.append({
+                         'document': search_results["documents"][0][i],
+                         'metadata': search_results["metadatas"][0][i],
+                         'score': 1.0 - search_results["distances"][0][i]  # Convert distance to similarity
+                     })
+
+             return results
+         except Exception as e:
+             logger.error(f"Error querying collection: {e}")
+             return []
+
+     def get_statistics(self) -> Dict[str, Any]:
+         """Get statistics about the vector store"""
+         stats = {}
+
+         try:
+             # Get collection count
+             collection_info = self.collection.count()
+             stats['total_documents'] = collection_info
+
+             # Estimate unique files - with no chunking, each document is a file
+             stats['unique_files'] = collection_info
+         except Exception as e:
+             logger.error(f"Error getting statistics: {e}")
+             stats['error'] = str(e)
+
+         return stats
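A hedged sketch of calling VectorStoreManager.query directly (illustrative only, not part of the commit; assumes a populated `markdown_docs` collection under ./chroma_data, and `store` plus the query text are hypothetical). Each hit is a dict with 'document', 'metadata', and a cosine-similarity 'score'.

    config = Config(local_dir="./chroma_data", collection_name="markdown_docs")
    store = VectorStoreManager(config)
    for hit in store.query("How do I configure logging?", n_results=3):
        print(round(hit["score"], 3), hit["metadata"].get("filename", "unknown"))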
+
+ class RAGSystem:
+     """Retrieval-Augmented Generation with multiple LLM providers"""
+
+     def __init__(self, vector_store: VectorStoreManager):
+         self.vector_store = vector_store
+         self.openai_client = None
+         self.gemini_configured = False
+
+     def setup_openai(self, api_key: str):
+         """Set up OpenAI client with API key"""
+         try:
+             self.openai_client = OpenAI(api_key=api_key)
+             return True
+         except Exception as e:
+             logger.error(f"Error initializing OpenAI client: {e}")
+             return False
+
+     def setup_gemini(self, api_key: str):
+         """Set up Gemini with API key"""
+         try:
+             genai.configure(api_key=api_key)
+             self.gemini_configured = True
+             return True
+         except Exception as e:
+             logger.error(f"Error configuring Gemini: {e}")
+             return False
+
+     def format_context(self, documents: List[Dict]) -> str:
+         """Format retrieved documents into context for the LLM"""
+         if not documents:
+             return "No relevant documents found."
+
+         context_parts = []
+         for i, doc in enumerate(documents):
+             metadata = doc['metadata']
+             title = metadata.get('title', metadata.get('filename', 'Unknown document'))
+
+             # For readability, limit length of context document
+             doc_text = doc['document']
+             if len(doc_text) > 10000:  # Limit long documents in context
+                 doc_text = doc_text[:10000] + "... [Document truncated for context]"
+
+             context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
+
+         return "\n".join(context_parts)
+
+     def generate_response_openai(self, query: str, context: str) -> str:
+         """Generate a response using OpenAI model with context"""
+         if not self.openai_client:
+             return "Error: OpenAI API key not configured. Please enter an API key in the settings tab."
+
+         system_prompt = """
+         You are a helpful assistant that answers questions based on the context provided.
+         Use the information from the context to answer the user's question.
+         If the context doesn't contain the information needed, say so clearly.
+         Always cite the specific sections from the context that you used in your answer.
+         """
+
+         try:
+             response = self.openai_client.chat.completions.create(
+                 model="gpt-4o-mini",  # Use GPT-4o mini
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
+                 ],
+                 temperature=0.3,  # Lower temperature for more factual responses
+                 max_tokens=1000,
+             )
+             return response.choices[0].message.content
+         except Exception as e:
+             logger.error(f"Error generating response with OpenAI: {e}")
+             return f"Error generating response with OpenAI: {str(e)}"
+
+     def generate_response_gemini(self, query: str, context: str) -> str:
+         """Generate a response using Gemini with context"""
+         if not self.gemini_configured:
+             return "Error: Google AI API key not configured. Please enter an API key in the settings tab."
+
+         prompt = f"""
+         You are a helpful assistant that answers questions based on the context provided.
+         Use the information from the context to answer the user's question.
+         If the context doesn't contain the information needed, say so clearly.
+         Always cite the specific sections from the context that you used in your answer.
+
+         Context:
+         {context}
+
+         Question: {query}
+         """
+
+         try:
+             model = genai.GenerativeModel('gemini-1.5-flash')
+             response = model.generate_content(prompt)
+             return response.text
+         except Exception as e:
+             logger.error(f"Error generating response with Gemini: {e}")
+             return f"Error generating response with Gemini: {str(e)}"
+
+     def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
+         """Retrieve relevant documents and generate a response using the specified model"""
+         # Query vector store
+         documents = self.vector_store.query(query, n_results=n_results)
+
+         if not documents:
+             return "No relevant documents found to answer your question."
+
+         # Format context
+         context = self.format_context(documents)
+
+         # Generate response with the appropriate model
+         if model == "openai":
+             return self.generate_response_openai(query, context)
+         elif model == "gemini":
+             return self.generate_response_gemini(query, context)
+         else:
+             return f"Unknown model: {model}"
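A sketch of the retrieval-then-generation flow without the Gradio UI (illustrative only, not part of the commit; assumes the `store` from the previous sketch and a valid key in the hypothetical variable OPENAI_KEY):

    rag = RAGSystem(store)                  # store: a VectorStoreManager as above
    if rag.setup_openai(OPENAI_KEY):        # OPENAI_KEY: placeholder for a real key
        answer = rag.query_and_generate("Summarize the setup instructions", n_results=3, model="openai")
        print(answer)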
+
+ def rag_chat(query, n_results, model_choice, rag_system):
+     """Function to handle RAG chat queries"""
+     return rag_system.query_and_generate(query, n_results=int(n_results), model=model_choice)
+
+ def simple_query(query, n_results, vector_store):
+     """Function to handle simple vector store queries"""
+     results = vector_store.query(query, n_results=int(n_results))
+
+     # Format results for display
+     formatted = []
+     for i, res in enumerate(results):
+         metadata = res['metadata']
+         title = metadata.get('title', metadata.get('filename', 'Unknown'))
+         # Limit preview text for display
+         preview = res['document'][:800] + '...' if len(res['document']) > 800 else res['document']
+         formatted.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n\n"
+                          f"**Source:** {title}\n\n"
+                          f"**Content:**\n{preview}\n\n"
+                          f"---\n")
+
+     return "\n".join(formatted) if formatted else "No results found."
+
+ def get_db_stats(vector_store):
+     """Function to get vector store statistics"""
+     stats = vector_store.get_statistics()
+     return (f"Total documents: {stats.get('total_documents', 0)}\n"
+             f"Unique files: {stats.get('unique_files', 0)}")
+
+ def update_api_keys(openai_key, gemini_key, rag_system):
+     """Update API keys for the RAG system"""
+     success_msg = []
+
+     if openai_key:
+         if rag_system.setup_openai(openai_key):
+             success_msg.append("✅ OpenAI API key configured successfully")
+         else:
+             success_msg.append("❌ Failed to configure OpenAI API key")
+
+     if gemini_key:
+         if rag_system.setup_gemini(gemini_key):
+             success_msg.append("✅ Google AI API key configured successfully")
+         else:
+             success_msg.append("❌ Failed to configure Google AI API key")
+
+     if not success_msg:
+         return "Please enter at least one API key"
+
+     return "\n".join(success_msg)
+
+ # Main function to run the application
+ def main():
+     # Set up paths for existing Chroma database
+     chroma_dir = Path("./chroma_data")
+
+     # Initialize the system
+     config = Config(
+         local_dir=str(chroma_dir),
+         collection_name="markdown_docs"
+     )
+
+     # Initialize vector store manager with existing collection
+     vector_store = VectorStoreManager(config)
+
+     # Initialize RAG system without API keys initially
+     rag_system = RAGSystem(vector_store)
+
+     # Define Gradio app
+     def rag_chat_wrapper(query, n_results, model_choice):
+         return rag_chat(query, n_results, model_choice, rag_system)
+
+     def simple_query_wrapper(query, n_results):
+         return simple_query(query, n_results, vector_store)
+
+     def update_api_keys_wrapper(openai_key, gemini_key):
+         return update_api_keys(openai_key, gemini_key, rag_system)
+
+     # Create the Gradio interface
+     with gr.Blocks(title="Markdown RAG System") as app:
+         gr.Markdown("# RAG System with Multiple LLM Providers")
+
+         with gr.Tab("Chat with Documents"):
+             with gr.Row():
+                 with gr.Column(scale=3):
+                     query_input = gr.Textbox(label="Question", placeholder="Ask a question about your documents...")
+                     num_results = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of documents to retrieve")
+                     model_choice = gr.Radio(
+                         choices=["openai", "gemini"],
+                         value="openai",
+                         label="Choose LLM Provider",
+                         info="Select which model to use for generating answers"
+                     )
+                     query_button = gr.Button("Ask", variant="primary")
+
+                 with gr.Column(scale=7):
+                     response_output = gr.Markdown(label="Response")
+
+             # Database stats
+             stats_display = gr.Textbox(label="Database Statistics", value=get_db_stats(vector_store))
+             refresh_button = gr.Button("Refresh Statistics")
+
+         with gr.Tab("Document Search"):
+             search_input = gr.Textbox(label="Search Query", placeholder="Search your documents...")
+             search_num = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
+             search_button = gr.Button("Search", variant="primary")
+             search_output = gr.Markdown(label="Search Results")
+
+         with gr.Tab("Settings"):
+             gr.Markdown("""
+             ## API Keys Configuration
+
+             This application can use either OpenAI's GPT-4o-mini or Google's Gemini 1.5 Flash for generating responses.
+             You need to provide at least one API key to use the chat functionality.
+             """)
+
+             openai_key_input = gr.Textbox(
+                 label="OpenAI API Key",
+                 placeholder="Enter your OpenAI API key here...",
+                 type="password"
+             )
+
+             gemini_key_input = gr.Textbox(
+                 label="Google AI API Key",
+                 placeholder="Enter your Google AI API key here...",
+                 type="password"
+             )
+
+             save_keys_button = gr.Button("Save API Keys", variant="primary")
+             api_status = gr.Markdown("")
+
+         # Set up events
+         query_button.click(
+             fn=rag_chat_wrapper,
+             inputs=[query_input, num_results, model_choice],
+             outputs=response_output
+         )
+
+         refresh_button.click(
+             fn=lambda: get_db_stats(vector_store),
+             inputs=None,
+             outputs=stats_display
+         )
+
+         search_button.click(
+             fn=simple_query_wrapper,
+             inputs=[search_input, search_num],
+             outputs=search_output
+         )
+
+         save_keys_button.click(
+             fn=update_api_keys_wrapper,
+             inputs=[openai_key_input, gemini_key_input],
+             outputs=api_status
+         )
+
+     # Launch the interface
+     app.launch()
+
  if __name__ == "__main__":
+     main()
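app.py only reads an existing collection; the commit does not include an ingestion step. A minimal sketch of how markdown files might be loaded into the expected `markdown_docs` collection using Chroma's standard add() call (the ./docs folder and the id scheme are hypothetical):

    from pathlib import Path
    import chromadb
    from chromadb.utils import embedding_functions

    client = chromadb.PersistentClient(path="./chroma_data")
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
    collection = client.get_or_create_collection(
        name="markdown_docs", embedding_function=ef, metadata={"hnsw:space": "cosine"}
    )

    for md_file in Path("./docs").glob("**/*.md"):     # hypothetical source folder
        collection.add(
            ids=[str(md_file)],                         # hypothetical id scheme: the file path
            documents=[md_file.read_text(encoding="utf-8")],
            metadatas=[{"filename": md_file.name, "title": md_file.stem}],
        )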