jzou19950715 committed on
Commit 78841ad · verified · 1 Parent(s): 6bf58ba

Update app.py

Files changed (1): app.py (+616 -223)

app.py CHANGED
@@ -4,48 +4,153 @@ import logging
 from pathlib import Path
 import json
 from datetime import datetime
-from typing import List, Dict, Any, Optional
 
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
 
-# Importing necessary libraries
-import torch
-import numpy as np
-from sentence_transformers import SentenceTransformer
-import chromadb
-from chromadb.utils import embedding_functions
-import gradio as gr
-from openai import OpenAI
-import google.generativeai as genai
 
-# Configuration class
 class Config:
-    """Configuration for vector store and RAG"""
 
     def __init__(self,
-                 local_dir: str = ".",
                  embedding_model: str = "all-MiniLM-L6-v2",
-                 collection_name: str = "markdown_docs"):
         self.local_dir = local_dir
         self.embedding_model = embedding_model
         self.collection_name = collection_name
 
-# Embedding engine
 class EmbeddingEngine:
-    """Handle embeddings with a lightweight model"""
 
     def __init__(self, model_name="all-MiniLM-L6-v2"):
         # Use GPU if available
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        logger.info(f"Using device: {self.device}")
 
         # Try multiple model options in order of preference
         model_options = [
             model_name,
-            "all-MiniLM-L6-v2",
-            "paraphrase-MiniLM-L3-v2",
-            "all-mpnet-base-v2"  # Higher quality but larger model
         ]
 
         self.model = None
@@ -53,47 +158,99 @@ class EmbeddingEngine:
         # Try each model in order until one works
         for model_option in model_options:
             try:
-                logger.info(f"Attempting to load model: {model_option}")
                 self.model = SentenceTransformer(model_option)
 
                 # Move model to device
                 self.model.to(self.device)
 
-                logger.info(f"Successfully loaded model: {model_option}")
                 self.model_name = model_option
                 self.vector_size = self.model.get_sentence_embedding_dimension()
                 break
 
             except Exception as e:
-                logger.warning(f"Failed to load model {model_option}: {str(e)}")
 
         if self.model is None:
-            logger.error("Failed to load any embedding model. Exiting.")
-            sys.exit(1)
 
 class VectorStoreManager:
-    """Manage Chroma vector store operations - upload, query, etc."""
 
     def __init__(self, config: Config):
         self.config = config
 
         # Initialize Chroma client (local persistence)
         logger.info(f"Initializing Chroma at {config.local_dir}")
-        self.client = chromadb.PersistentClient(path=config.local_dir)
 
         # Get or create collection
         try:
             # Initialize embedding model
             logger.info("Loading embedding model...")
             self.embedding_engine = EmbeddingEngine(config.embedding_model)
-            logger.info(f"Using model: {self.embedding_engine.model_name}")
 
             # Create embedding function
             sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                 model_name=self.embedding_engine.model_name
             )
 
-            # Try to get existing collection
             try:
                 self.collection = self.client.get_collection(
                     name=config.collection_name,
@@ -101,7 +258,7 @@ class VectorStoreManager:
                 )
                 logger.info(f"Using existing collection: {config.collection_name}")
             except Exception as e:
-                logger.error(f"Error getting collection: {e}")
                 # Attempt to get a list of available collections
                 collections = self.client.list_collections()
                 if collections:
@@ -122,19 +279,33 @@ class VectorStoreManager:
             logger.info(f"Created new collection: {config.collection_name}")
 
         except Exception as e:
-            logger.error(f"Error initializing Chroma collection: {e}")
-            sys.exit(1)
 
     def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
         """
-        Query the vector store with a text query
         """
         try:
             # Query the collection
             search_results = self.collection.query(
                 query_texts=[query_text],
                 n_results=n_results,
-                include=["documents", "metadatas", "distances"]
             )
 
             # Format results
@@ -143,26 +314,106 @@ class VectorStoreManager:
             for i in range(len(search_results["documents"][0])):
                 results.append({
                     'document': search_results["documents"][0][i],
-                    'metadata': search_results["metadatas"][0][i],
-                    'score': 1.0 - search_results["distances"][0][i]  # Convert distance to similarity
                 })
 
             return results
         except Exception as e:
             logger.error(f"Error querying collection: {e}")
             return []
 
     def get_statistics(self) -> Dict[str, Any]:
-        """Get statistics about the vector store"""
-        stats = {}
 
         try:
             # Get collection count
-            collection_info = self.collection.count()
-            stats['total_documents'] = collection_info
 
-            # Estimate unique files - with no chunking, each document is a file
-            stats['unique_files'] = collection_info
         except Exception as e:
             logger.error(f"Error getting statistics: {e}")
             stats['error'] = str(e)
@@ -170,280 +421,422 @@ class VectorStoreManager:
         return stats
 
 class RAGSystem:
-    """Retrieval-Augmented Generation with multiple LLM providers"""
 
-    def __init__(self, vector_store: VectorStoreManager):
         self.vector_store = vector_store
         self.openai_client = None
         self.gemini_configured = False
 
-    def setup_openai(self, api_key: str):
-        """Set up OpenAI client with API key"""
         try:
             self.openai_client = OpenAI(api_key=api_key)
             return True
         except Exception as e:
             logger.error(f"Error initializing OpenAI client: {e}")
             return False
 
-    def setup_gemini(self, api_key: str):
-        """Set up Gemini with API key"""
         try:
             genai.configure(api_key=api_key)
             self.gemini_configured = True
             return True
         except Exception as e:
             logger.error(f"Error configuring Gemini: {e}")
             return False
 
     def format_context(self, documents: List[Dict]) -> str:
-        """Format retrieved documents into context for the LLM"""
         if not documents:
             return "No relevant documents found."
 
         context_parts = []
         for i, doc in enumerate(documents):
             metadata = doc['metadata']
             title = metadata.get('title', metadata.get('filename', 'Unknown document'))
 
             # For readability, limit length of context document
             doc_text = doc['document']
-            if len(doc_text) > 10000:  # Limit long documents in context
-                doc_text = doc_text[:10000] + "... [Document truncated for context]"
 
-            context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
 
-        return "\n".join(context_parts)
 
     def generate_response_openai(self, query: str, context: str) -> str:
-        """Generate a response using OpenAI model with context"""
         if not self.openai_client:
             return "Error: OpenAI API key not configured. Please enter an API key in the API key field."
 
         system_prompt = """
-        You are a helpful assistant that answers questions based on the context provided.
-        Use the information from the context to answer the user's question.
-        If the context doesn't contain the information needed, say so clearly.
-        Always cite the specific sections from the context that you used in your answer.
         """
 
         try:
             response = self.openai_client.chat.completions.create(
-                model="gpt-4o-mini",  # Use GPT-4o mini
                 messages=[
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                 ],
-                temperature=0.3,  # Lower temperature for more factual responses
-                max_tokens=1000,
             )
-            return response.choices[0].message.content
         except Exception as e:
-            logger.error(f"Error generating response with OpenAI: {e}")
-            return f"Error generating response with OpenAI: {str(e)}"
 
     def generate_response_gemini(self, query: str, context: str) -> str:
-        """Generate a response using Gemini with context"""
         if not self.gemini_configured:
             return "Error: Google AI API key not configured. Please enter an API key in the API key field."
 
-        prompt = f"""
         You are a highly supportive and insightful assistant dedicated to providing clear, helpful, and well-structured answers based on the given context. Your goal is to ensure the user receives a thorough, encouraging, and informative response that directly addresses their question.
 
         **Guidelines for Your Response:**
-        - Use the **context** to form a detailed and well-reasoned answer.
         - If the context lacks sufficient information, state it clearly while offering general insights or related knowledge.
-        - Cite specific sections from the context that contribute to your response.
         - Maintain a **friendly, professional, and supportive** tone that encourages user engagement.
         - Aim for **clarity and depth**, breaking down complex ideas into easy-to-understand explanations.
-        - Strive for a response length of **300-500 words**, ensuring **both completeness and readability**.
 
         **Context:**
         {context}
 
-        **Users Question:**
        {query}
 
        **Your Response:**
-        """
 
        try:
-            model = genai.GenerativeModel('gemini-1.5-flash')
-            response = model.generate_content(prompt)
-            return response.text
        except Exception as e:
-            logger.error(f"Error generating response with Gemini: {e}")
-            return f"Error generating response with Gemini: {str(e)}"
 
-    def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
-        """Retrieve relevant documents and generate a response using the specified model"""
         # Query vector store
         documents = self.vector_store.query(query, n_results=n_results)
 
         if not documents:
-            return "No relevant documents found to answer your question."
 
         # Format context
         context = self.format_context(documents)
 
         # Generate response with the appropriate model
         if model == "openai":
-            return self.generate_response_openai(query, context)
         elif model == "gemini":
-            return self.generate_response_gemini(query, context)
         else:
-            return f"Unknown model: {model}"
 
-# Main function to run the application
 def main():
-    # Initialize the system with current directory as the Chroma location
-    config = Config(
-        local_dir=".",  # Look for Chroma files in current directory
-        collection_name="markdown_docs"
-    )
 
     try:
         # Initialize vector store manager with existing collection
         vector_store = VectorStoreManager(config)
 
         # Initialize RAG system without API keys initially
-        rag_system = RAGSystem(vector_store)
 
         # Create the Gradio interface
-        with gr.Blocks(title="Document RAG System") as app:
-            gr.Markdown("# Document RAG System")
 
             with gr.Row():
                 with gr.Column(scale=1):
                     # API Keys and model selection
-                    model_choice = gr.Radio(
-                        choices=["openai", "gemini"],
-                        value="openai",
-                        label="Choose LLM Provider",
-                        info="Select which model to use (GPT-4o mini or Gemini 1.5 Flash)"
-                    )
-
-                    api_key_input = gr.Textbox(
-                        label="API Key",
-                        placeholder="Enter your API key here...",
-                        type="password"
-                    )
-
-                    save_key_button = gr.Button("Save API Key", variant="primary")
-                    api_status = gr.Markdown("")
 
                     # Search controls
-                    num_results = gr.Slider(
-                        minimum=1,
-                        maximum=10,
-                        value=10,
-                        step=1,
-                        label="Number of documents to retrieve"
-                    )
-
-                    # Database stats
-                    gr.Markdown("### Database Statistics")
-                    stats_display = gr.Textbox(
-                        label="",
-                        value=get_db_stats(vector_store),
-                        lines=2
-                    )
-                    refresh_button = gr.Button("Refresh Stats")
-
-                with gr.Column(scale=2):
-                    # Query and response
-                    query_input = gr.Textbox(
-                        label="Your Question",
-                        placeholder="Ask a question about your documents...",
-                        lines=2
-                    )
-
-                    query_button = gr.Button("Ask Question", variant="primary")
-
-                    gr.Markdown("### Response")
-                    response_output = gr.Markdown()
-
-                    gr.Markdown("### Document Search Results")
-                    search_output = gr.Markdown()
-
-            # Function to update API key based on selected model
-            def update_api_key(api_key, model):
-                if model == "openai":
-                    success = rag_system.setup_openai(api_key)
-                    model_name = "OpenAI GPT-4o mini"
-                else:
-                    success = rag_system.setup_gemini(api_key)
-                    model_name = "Google Gemini 1.5 Flash"
-
-                if success:
-                    return f"✅ {model_name} API key configured successfully"
-                else:
-                    return f"❌ Failed to configure {model_name} API key"
-
-            # Query function that returns both response and search results
-            def query_and_search(query, n_results, model):
-                # Get search results first
-                results = vector_store.query(query, n_results=int(n_results))
-
-                # Format search results
-                formatted_results = []
-                for i, res in enumerate(results):
-                    metadata = res['metadata']
-                    title = metadata.get('title', metadata.get('filename', 'Unknown'))
-                    preview = res['document'][:500] + '...' if len(res['document']) > 500 else res['document']
-                    formatted_results.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n"
-                                             f"**Source:** {title}\n"
-                                             f"**Preview:**\n{preview}\n\n---\n")
-
-                search_output_text = "\n".join(formatted_results) if formatted_results else "No results found."
-
-                # Generate response if we have results
-                response = "No documents found to answer your question."
-                if results:
-                    context = rag_system.format_context(results)
-                    if model == "openai":
-                        response = rag_system.generate_response_openai(query, context)
-                    else:
-                        response = rag_system.generate_response_gemini(query, context)
-
-                return response, search_output_text
-
-            # Set up events
-            save_key_button.click(
-                fn=update_api_key,
-                inputs=[api_key_input, model_choice],
-                outputs=api_status
-            )
-
-            query_button.click(
-                fn=query_and_search,
-                inputs=[query_input, num_results, model_choice],
-                outputs=[response_output, search_output]
-            )
-
-            refresh_button.click(
-                fn=lambda: get_db_stats(vector_store),
-                inputs=None,
-                outputs=stats_display
-            )
-
-        # Launch the interface
-        app.launch()
-
-    except Exception as e:
-        logger.error(f"Error initializing application: {e}")
-        print(f"Error: {e}")
-        sys.exit(1)
-
-# Helper function to get database stats
-def get_db_stats(vector_store):
-    """Function to get vector store statistics"""
-    try:
-        stats = vector_store.get_statistics()
-        return f"Total documents: {stats.get('total_documents', 0)}"
-    except Exception as e:
-        logger.error(f"Error getting statistics: {e}")
-        return "Error getting database statistics"
-
-if __name__ == "__main__":
-    main()
 from pathlib import Path
 import json
 from datetime import datetime
+from typing import List, Dict, Any, Optional, Tuple, Union
+import traceback
 
+# Configure detailed logging with file output
+LOG_DIR = "logs"
+os.makedirs(LOG_DIR, exist_ok=True)
+log_file = os.path.join(LOG_DIR, f"rag_system_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
 
+# Set up root logger with both file and console handlers
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler(log_file),
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+logger = logging.getLogger("rag_system")
+logger.info(f"Starting RAG system. Log file: {log_file}")
+
+# Importing necessary libraries with error handling
+try:
+    import torch
+    import numpy as np
+    from sentence_transformers import SentenceTransformer
+    import chromadb
+    from chromadb.utils import embedding_functions
+    import gradio as gr
+    from openai import OpenAI
+    import google.generativeai as genai
+    logger.info("All required libraries successfully imported")
+except ImportError as e:
+    logger.critical(f"Failed to import required libraries: {e}")
+    print("ERROR: Missing required libraries. Please install with: pip install -r requirements.txt")
+    print(f"Specific error: {e}")
+    sys.exit(1)
+
+# Version info for tracking
+VERSION = "1.0.0"
+logger.info(f"RAG System Version: {VERSION}")
 
 class Config:
+    """
+    Configuration for vector store and RAG system.
+
+    This class centralizes all configuration parameters for the application,
+    making it easier to modify settings and ensure consistency.
+
+    Attributes:
+        local_dir (str): Directory for ChromaDB persistence
+        embedding_model (str): Name of the embedding model to use
+        collection_name (str): Name of the ChromaDB collection
+        default_top_k (int): Default number of results to return
+        openai_model (str): Default OpenAI model to use
+        gemini_model (str): Default Gemini model to use
+        temperature (float): Temperature setting for LLM generation
+        max_tokens (int): Maximum tokens for LLM response
+        system_name (str): Name of the system for UI
+    """
+
     def __init__(self,
+                 local_dir: str = "./chroma_db",
                  embedding_model: str = "all-MiniLM-L6-v2",
+                 collection_name: str = "markdown_docs",
+                 default_top_k: int = 5,
+                 openai_model: str = "gpt-4o-mini",
+                 gemini_model: str = "gemini-1.5-flash",
+                 temperature: float = 0.3,
+                 max_tokens: int = 1000,
+                 system_name: str = "Document RAG System"):
         self.local_dir = local_dir
         self.embedding_model = embedding_model
         self.collection_name = collection_name
+        self.default_top_k = default_top_k
+        self.openai_model = openai_model
+        self.gemini_model = gemini_model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.system_name = system_name
+
+        # Create local directory if it doesn't exist
+        os.makedirs(local_dir, exist_ok=True)
+
+        logger.info(f"Initialized configuration: {self.__dict__}")
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert configuration to dictionary for serialization"""
+        return self.__dict__
+
+    @classmethod
+    def from_file(cls, config_path: str) -> 'Config':
+        """Load configuration from JSON file"""
+        try:
+            with open(config_path, 'r') as f:
+                config_dict = json.load(f)
+            logger.info(f"Loaded configuration from {config_path}")
+            return cls(**config_dict)
+        except Exception as e:
+            logger.error(f"Failed to load configuration from {config_path}: {e}")
+            logger.info("Using default configuration")
+            return cls()
+
+    def save_to_file(self, config_path: str) -> bool:
+        """Save configuration to JSON file"""
+        try:
+            with open(config_path, 'w') as f:
+                json.dump(self.to_dict(), f, indent=2)
+            logger.info(f"Saved configuration to {config_path}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to save configuration to {config_path}: {e}")
+            return False
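A quick sketch of how the new Config persistence round-trip might be used; the `rag_config.json` path mirrors the one `main()` uses further down, and the snippet is illustrative rather than part of the commit:

```python
import os

CONFIG_FILE_PATH = "rag_config.json"  # same path main() uses below

if os.path.exists(CONFIG_FILE_PATH):
    config = Config.from_file(CONFIG_FILE_PATH)   # falls back to defaults on parse errors
else:
    config = Config(local_dir="./chroma_db")      # defaults from __init__
    config.save_to_file(CONFIG_FILE_PATH)         # persist for the next run

print(config.openai_model)  # "gpt-4o-mini" unless overridden in the JSON file
```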
 
 class EmbeddingEngine:
+    """
+    Handle embeddings with a lightweight model.
+
+    This class manages the embedding model used to convert text to vector
+    representations for semantic search.
+
+    Attributes:
+        model (SentenceTransformer): The loaded embedding model
+        model_name (str): Name of the successfully loaded model
+        vector_size (int): Dimension of the embedding vectors
+        device (str): Device used for inference ('cuda' or 'cpu')
+    """
 
     def __init__(self, model_name="all-MiniLM-L6-v2"):
+        """
+        Initialize the embedding engine with the specified model.
+
+        Args:
+            model_name (str): Name of the embedding model to load
+
+        Raises:
+            SystemExit: If no embedding model could be loaded
+        """
         # Use GPU if available
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device for embeddings: {self.device}")
 
         # Try multiple model options in order of preference
         model_options = [
             model_name,
+            "all-MiniLM-L6-v2",         # Good balance of speed and quality
+            "paraphrase-MiniLM-L3-v2",  # Faster but less accurate
+            "all-mpnet-base-v2"         # Higher quality but larger model
         ]
 
         self.model = None
 
         # Try each model in order until one works
         for model_option in model_options:
             try:
+                logger.info(f"Attempting to load embedding model: {model_option}")
                 self.model = SentenceTransformer(model_option)
 
                 # Move model to device
                 self.model.to(self.device)
 
+                logger.info(f"Successfully loaded embedding model: {model_option}")
                 self.model_name = model_option
                 self.vector_size = self.model.get_sentence_embedding_dimension()
+                logger.info(f"Embedding vector size: {self.vector_size}")
                 break
 
             except Exception as e:
+                logger.warning(f"Failed to load embedding model {model_option}: {str(e)}")
 
         if self.model is None:
+            error_msg = "Failed to load any embedding model. Please check your internet connection or install models locally."
+            logger.critical(error_msg)
+            raise SystemExit(error_msg)
+
+    def embed(self, texts: List[str]) -> np.ndarray:
+        """
+        Generate embeddings for a list of texts.
+
+        Args:
+            texts (List[str]): List of texts to embed
+
+        Returns:
+            np.ndarray: Array of embeddings
+
+        Raises:
+            ValueError: If the input is invalid
+            RuntimeError: If embedding fails
+        """
+        if not texts:
+            raise ValueError("Cannot embed empty list of texts")
+
+        try:
+            embeddings = self.model.encode(texts, convert_to_numpy=True)
+            return embeddings
+        except Exception as e:
+            logger.error(f"Error generating embeddings: {e}")
+            raise RuntimeError(f"Failed to generate embeddings: {e}")
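For reference, a minimal sketch of how the new `embed` method could be exercised; the cosine-similarity computation is an illustration on top of the returned arrays, not something this commit adds:

```python
import numpy as np

engine = EmbeddingEngine()                      # loads the first model that works
vecs = engine.embed(["markdown docs", "vector search"])
assert vecs.shape == (2, engine.vector_size)    # e.g. 384 for all-MiniLM-L6-v2

# Cosine similarity between the two embeddings (illustrative helper)
sim = float(np.dot(vecs[0], vecs[1]) /
            (np.linalg.norm(vecs[0]) * np.linalg.norm(vecs[1])))
print(f"similarity: {sim:.3f}")
```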
 
 class VectorStoreManager:
+    """
+    Manage Chroma vector store operations - upload, query, etc.
+
+    This class provides an interface to the ChromaDB vector database,
+    handling document storage, retrieval, and management.
+
+    Attributes:
+        config (Config): Configuration parameters
+        client (chromadb.PersistentClient): ChromaDB client
+        collection (chromadb.Collection): The active ChromaDB collection
+        embedding_engine (EmbeddingEngine): Engine for generating embeddings
+    """
 
     def __init__(self, config: Config):
+        """
+        Initialize the vector store manager.
+
+        Args:
+            config (Config): Configuration parameters
+
+        Raises:
+            SystemExit: If the vector store cannot be initialized
+        """
         self.config = config
 
         # Initialize Chroma client (local persistence)
         logger.info(f"Initializing Chroma at {config.local_dir}")
+        try:
+            self.client = chromadb.PersistentClient(path=config.local_dir)
+            logger.info("ChromaDB client initialized successfully")
+        except Exception as e:
+            error_msg = f"Failed to initialize ChromaDB client: {e}"
+            logger.critical(error_msg)
+            raise SystemExit(error_msg)
 
         # Get or create collection
         try:
             # Initialize embedding model
             logger.info("Loading embedding model...")
             self.embedding_engine = EmbeddingEngine(config.embedding_model)
+            logger.info(f"Using embedding model: {self.embedding_engine.model_name}")
 
             # Create embedding function
             sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                 model_name=self.embedding_engine.model_name
             )
 
+            # Try to get existing collection or create a new one
             try:
                 self.collection = self.client.get_collection(
                     name=config.collection_name,
                 )
                 logger.info(f"Using existing collection: {config.collection_name}")
             except Exception as e:
+                logger.warning(f"Error getting collection: {e}")
                 # Attempt to get a list of available collections
                 collections = self.client.list_collections()
                 if collections:
 
             logger.info(f"Created new collection: {config.collection_name}")
 
         except Exception as e:
+            error_msg = f"Error initializing Chroma collection: {e}"
+            logger.critical(error_msg)
+            raise SystemExit(error_msg)
 
     def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
         """
+        Query the vector store with a text query.
+
+        Args:
+            query_text (str): The query text
+            n_results (int): Number of results to return
+
+        Returns:
+            List[Dict]: List of results with document text, metadata, and similarity score
         """
+        if not query_text.strip():
+            logger.warning("Empty query received")
+            return []
+
         try:
+            logger.info(f"Querying vector store with: '{query_text[:50]}...' (top {n_results})")
+
             # Query the collection
             search_results = self.collection.query(
                 query_texts=[query_text],
                 n_results=n_results,
+                include=["documents", "metadatas", "distances", "embeddings"]
             )
 
             # Format results
             results = []
             if search_results["documents"] and search_results["documents"][0]:
                 for i in range(len(search_results["documents"][0])):
                     results.append({
                         'document': search_results["documents"][0][i],
+                        'metadata': search_results["metadatas"][0][i] if search_results["metadatas"] else {},
+                        'score': 1.0 - search_results["distances"][0][i],  # Convert distance to similarity
+                        'distance': search_results["distances"][0][i]
                     })
+
+                logger.info(f"Found {len(results)} results for query")
+            else:
+                logger.info("No results found for query")
 
             return results
         except Exception as e:
             logger.error(f"Error querying collection: {e}")
+            logger.debug(traceback.format_exc())
             return []
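A sketch of what a caller sees from the reworked `query`; the `score` field assumes the collection's default distance metric, where smaller distance means more similar:

```python
store = VectorStoreManager(Config())

for hit in store.query("how do I configure logging?", n_results=3):
    # 'score' is 1.0 - distance, so higher means more similar
    print(f"{hit['score']:.2f}  {hit['metadata'].get('filename', '?')}")
    print(hit['document'][:80], "...")
```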
 
+    def add_document(self,
+                     document: str,
+                     doc_id: str,
+                     metadata: Dict[str, Any]) -> bool:
+        """
+        Add a document to the vector store.
+
+        Args:
+            document (str): The document text
+            doc_id (str): Unique identifier for the document
+            metadata (Dict[str, Any]): Metadata about the document
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            logger.info(f"Adding document '{doc_id}' to vector store")
+
+            # Add the document to the collection
+            self.collection.add(
+                documents=[document],
+                ids=[doc_id],
+                metadatas=[metadata]
+            )
+
+            logger.info(f"Successfully added document '{doc_id}'")
+            return True
+        except Exception as e:
+            logger.error(f"Error adding document to collection: {e}")
+            return False
+
+    def delete_document(self, doc_id: str) -> bool:
+        """
+        Delete a document from the vector store.
+
+        Args:
+            doc_id (str): ID of the document to delete
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            logger.info(f"Deleting document '{doc_id}' from vector store")
+            self.collection.delete(ids=[doc_id])
+            logger.info(f"Successfully deleted document '{doc_id}'")
+            return True
+        except Exception as e:
+            logger.error(f"Error deleting document from collection: {e}")
+            return False
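The new CRUD helpers might be driven like this; the doc_id scheme and metadata values are made up for illustration:

```python
store = VectorStoreManager(Config())

ok = store.add_document(
    document="# Setup\nRun `pip install -r requirements.txt` first.",
    doc_id="docs/setup.md",                        # hypothetical ID scheme
    metadata={"filename": "setup.md", "title": "Setup"},
)
assert ok

store.delete_document("docs/setup.md")             # returns False on failure, never raises
```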
 
     def get_statistics(self) -> Dict[str, Any]:
+        """
+        Get statistics about the vector store.
+
+        Returns:
+            Dict[str, Any]: Statistics about the vector store
+        """
+        stats = {
+            'collection_name': self.config.collection_name,
+            'embedding_model': self.embedding_engine.model_name,
+            'embedding_dimensions': self.embedding_engine.vector_size,
+            'device': self.embedding_engine.device
+        }
 
         try:
             # Get collection count
+            collection_count = self.collection.count()
+            stats['total_documents'] = collection_count
+
+            # Get unique metadata values
+            if collection_count > 0:
+                try:
+                    # Get a sample of document metadata
+                    sample_results = self.collection.get(limit=min(collection_count, 100))
+                    if sample_results and 'metadatas' in sample_results and sample_results['metadatas']:
+                        # Count unique files if filename exists in metadata
+                        filenames = set()
+                        for metadata in sample_results['metadatas']:
+                            if 'filename' in metadata:
+                                filenames.add(metadata['filename'])
+                        stats['unique_files'] = len(filenames)
+                except Exception as e:
+                    logger.warning(f"Error getting metadata statistics: {e}")
 
+            logger.info(f"Vector store statistics: {stats}")
         except Exception as e:
             logger.error(f"Error getting statistics: {e}")
             stats['error'] = str(e)
 
         return stats
  class RAGSystem:
424
+ """
425
+ Retrieval-Augmented Generation with multiple LLM providers.
426
+
427
+ This class handles the RAG workflow: retrieval of relevant documents,
428
+ formatting context, and generating responses with different LLM providers.
429
 
430
+ Attributes:
431
+ vector_store (VectorStoreManager): Manager for vector store operations
432
+ openai_client (Optional[OpenAI]): OpenAI client
433
+ gemini_configured (bool): Whether Gemini API is configured
434
+ config (Config): Configuration parameters
435
+ """
436
+
437
+ def __init__(self, vector_store: VectorStoreManager, config: Config):
438
+ """
439
+ Initialize the RAG system.
440
+
441
+ Args:
442
+ vector_store (VectorStoreManager): Vector store manager
443
+ config (Config): Configuration parameters
444
+ """
445
  self.vector_store = vector_store
446
+ self.config = config
447
  self.openai_client = None
448
  self.gemini_configured = False
449
+
450
+ logger.info("Initialized RAG system")
451
 
 
+    def setup_openai(self, api_key: str) -> bool:
+        """
+        Set up OpenAI client with API key.
+
+        Args:
+            api_key (str): OpenAI API key
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        if not api_key.strip():
+            logger.warning("Empty OpenAI API key provided")
+            return False
+
         try:
+            logger.info("Setting up OpenAI client")
             self.openai_client = OpenAI(api_key=api_key)
+            # Test the API key with a simple request
+            response = self.openai_client.chat.completions.create(
+                model=self.config.openai_model,
+                messages=[
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "Test connection"}
+                ],
+                max_tokens=10
+            )
+            logger.info("OpenAI client configured successfully")
             return True
         except Exception as e:
             logger.error(f"Error initializing OpenAI client: {e}")
+            self.openai_client = None
             return False
 
+    def setup_gemini(self, api_key: str) -> bool:
+        """
+        Set up Gemini with API key.
+
+        Args:
+            api_key (str): Google AI API key
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        if not api_key.strip():
+            logger.warning("Empty Gemini API key provided")
+            return False
+
         try:
+            logger.info("Setting up Gemini client")
             genai.configure(api_key=api_key)
+
+            # Test the API key with a simple request
+            model = genai.GenerativeModel(self.config.gemini_model)
+            response = model.generate_content("Test connection")
+
             self.gemini_configured = True
+            logger.info("Gemini client configured successfully")
             return True
         except Exception as e:
             logger.error(f"Error configuring Gemini: {e}")
+            self.gemini_configured = False
             return False
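Both setup methods now validate the key with a real test request (which consumes a small number of tokens), so bad keys fail at save time rather than on the first query. A sketch of driving them outside the UI, reading keys from environment variables; the variable names are conventional, not something this commit reads itself:

```python
import os

rag = RAGSystem(vector_store, config)

if rag.setup_openai(os.environ.get("OPENAI_API_KEY", "")):
    print("OpenAI ready")
else:
    # Fall back to Gemini if the OpenAI key is missing or invalid
    rag.setup_gemini(os.environ.get("GOOGLE_API_KEY", ""))
```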
 
     def format_context(self, documents: List[Dict]) -> str:
+        """
+        Format retrieved documents into context for the LLM.
+
+        Args:
+            documents (List[Dict]): List of retrieved documents
+
+        Returns:
+            str: Formatted context for the LLM
+        """
         if not documents:
+            logger.warning("No documents provided for context formatting")
             return "No relevant documents found."
 
+        logger.info(f"Formatting {len(documents)} documents for context")
         context_parts = []
+
         for i, doc in enumerate(documents):
             metadata = doc['metadata']
+            # Extract document metadata in a robust way
             title = metadata.get('title', metadata.get('filename', 'Unknown document'))
+            source = metadata.get('source', metadata.get('path', 'Unknown source'))
+            date = metadata.get('date', metadata.get('created_at', 'Unknown date'))
+
+            # Format header with metadata
+            header = f"Document {i+1} - {title}"
+            if source != 'Unknown source':
+                header += f" (Source: {source})"
+            if date != 'Unknown date':
+                header += f" (Date: {date})"
 
             # For readability, limit length of context document
             doc_text = doc['document']
+            if len(doc_text) > 8000:  # Limit long documents in context
+                doc_text = doc_text[:8000] + "... [Document truncated for context]"
 
+            context_parts.append(f"{header}:\n{doc_text}\n")
+
+        full_context = "\n".join(context_parts)
+        logger.info(f"Created context with {len(full_context)} characters")
 
+        return full_context
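For a document dict shaped like the ones `query` returns, the new header logic produces output along these lines (all values invented for illustration):

```python
doc = {
    "document": "Install with pip...",
    "metadata": {"title": "Setup Guide", "source": "docs/setup.md", "date": "2024-01-15"},
}
print(rag.format_context([doc]))
# Document 1 - Setup Guide (Source: docs/setup.md) (Date: 2024-01-15):
# Install with pip...
```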
 
     def generate_response_openai(self, query: str, context: str) -> str:
+        """
+        Generate a response using OpenAI model with context.
+
+        Args:
+            query (str): User query
+            context (str): Formatted document context
+
+        Returns:
+            str: Generated response
+        """
         if not self.openai_client:
+            logger.warning("OpenAI API key not configured for response generation")
             return "Error: OpenAI API key not configured. Please enter an API key in the API key field."
 
         system_prompt = """
+        You are a helpful, detailed, and accurate assistant that answers questions based on the context provided.
+        Follow these guidelines:
+
+        1. Use ONLY the information from the context to answer the user's question.
+        2. If the context doesn't contain the information needed, say so clearly and do your best to deduce and infer the answer.
+        3. Always cite the specific documents from the context that you used in your answer by referencing their number (e.g., "According to Document 1...").
+        4. Organize your response in a clear, structured format with headings where appropriate.
+        5. Follow best practices for clear writing.
+        6. If the information in different documents conflicts, acknowledge this and explain the different perspectives.
+        7. Be specific and detailed in your answers, focusing on accuracy over brevity.
+        8. Aim to be educational and informative in your tone.
+        9. Aim to write a comprehensive answer of 300-500 words to the user's question.
         """
 
         try:
+            logger.info(f"Generating response with OpenAI ({self.config.openai_model})")
+
+            start_time = datetime.now()
             response = self.openai_client.chat.completions.create(
+                model=self.config.openai_model,
                 messages=[
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                 ],
+                temperature=self.config.temperature,
+                max_tokens=self.config.max_tokens,
             )
+
+            generation_time = (datetime.now() - start_time).total_seconds()
+            response_text = response.choices[0].message.content
+
+            logger.info(f"Generated response with OpenAI in {generation_time:.2f} seconds")
+            return response_text
         except Exception as e:
+            error_msg = f"Error generating response with OpenAI: {str(e)}"
+            logger.error(error_msg)
+            return f"Error: {error_msg}"
 
     def generate_response_gemini(self, query: str, context: str) -> str:
+        """
+        Generate a response using Gemini with context.
+
+        Args:
+            query (str): User query
+            context (str): Formatted document context
+
+        Returns:
+            str: Generated response
+        """
         if not self.gemini_configured:
+            logger.warning("Gemini API key not configured for response generation")
             return "Error: Google AI API key not configured. Please enter an API key in the API key field."
 
+        prompt = f"""
         You are a highly supportive and insightful assistant dedicated to providing clear, helpful, and well-structured answers based on the given context. Your goal is to ensure the user receives a thorough, encouraging, and informative response that directly addresses their question.
 
         **Guidelines for Your Response:**
+        - Use ONLY the information from the **context** to form a detailed and well-reasoned answer.
         - If the context lacks sufficient information, state it clearly while offering general insights or related knowledge.
+        - Cite specific sections from the context by referring to document numbers (e.g., "According to Document 1...").
         - Maintain a **friendly, professional, and supportive** tone that encourages user engagement.
         - Aim for **clarity and depth**, breaking down complex ideas into easy-to-understand explanations.
+        - Organize your response with headings and sections if appropriate.
+        - Do not make up information or use knowledge outside of the provided context.
+        - If information in different documents conflicts, explain the different perspectives.
 
         **Context:**
         {context}
 
+        **User's Question:**
         {query}
 
         **Your Response:**
+        """
 
         try:
+            logger.info(f"Generating response with Gemini ({self.config.gemini_model})")
+
+            start_time = datetime.now()
+            model = genai.GenerativeModel(self.config.gemini_model)
+
+            generation_config = {
+                "temperature": self.config.temperature,
+                "max_output_tokens": self.config.max_tokens,
+                "top_p": 0.9,
+                "top_k": 40
+            }
+
+            response = model.generate_content(
+                prompt,
+                generation_config=generation_config
+            )
+
+            generation_time = (datetime.now() - start_time).total_seconds()
+            response_text = response.text
+
+            logger.info(f"Generated response with Gemini in {generation_time:.2f} seconds")
+            return response_text
         except Exception as e:
+            error_msg = f"Error generating response with Gemini: {str(e)}"
+            logger.error(error_msg)
+            return f"Error: {error_msg}"
 
+    def query_and_generate(self,
+                           query: str,
+                           n_results: int = 5,
+                           model: str = "openai") -> Tuple[str, str]:
+        """
+        Retrieve relevant documents and generate a response using the specified model.
+
+        Args:
+            query (str): User query
+            n_results (int): Number of documents to retrieve
+            model (str): Model provider to use ('openai' or 'gemini')
+
+        Returns:
+            Tuple[str, str]: (Generated response, Search results)
+        """
+        if not query.strip():
+            logger.warning("Empty query received")
+            return "Please enter a question to get a response.", "No search performed."
+
+        logger.info(f"Processing query: '{query[:50]}...' with {model} model")
+
         # Query vector store
         documents = self.vector_store.query(query, n_results=n_results)
 
+        # Format search results
+        formatted_results = []
+        for i, res in enumerate(documents):
+            metadata = res['metadata']
+            title = metadata.get('title', metadata.get('filename', 'Unknown'))
+            preview = res['document'][:500] + '...' if len(res['document']) > 500 else res['document']
+            formatted_results.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n"
+                                     f"**Source:** {title}\n"
+                                     f"**Preview:**\n{preview}\n\n---\n")
+
+        search_output_text = "\n".join(formatted_results) if formatted_results else "No results found."
+
         if not documents:
+            logger.warning("No relevant documents found")
+            return "No relevant documents found to answer your question.", search_output_text
 
         # Format context
         context = self.format_context(documents)
 
         # Generate response with the appropriate model
         if model == "openai":
+            response = self.generate_response_openai(query, context)
         elif model == "gemini":
+            response = self.generate_response_gemini(query, context)
         else:
+            error_msg = f"Unknown model: {model}"
+            logger.error(error_msg)
+            return error_msg, search_output_text
+
+        return response, search_output_text
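End to end, the refactored method now returns both strings, which is what lets the Gradio handler below bind one call to two output components. A sketch, assuming a configured `RAGSystem` named `rag_system`:

```python
response, results_md = rag_system.query_and_generate(
    query="How is logging configured?",
    n_results=5,
    model="openai",            # or "gemini"
)
print(results_md)              # markdown previews with similarity scores
print(response)                # the generated answer (or an error string)
```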
 
+def get_db_stats(vector_store: VectorStoreManager) -> str:
+    """
+    Function to get vector store statistics.
+
+    Args:
+        vector_store (VectorStoreManager): Vector store manager
+
+    Returns:
+        str: Formatted statistics string
+    """
+    try:
+        stats = vector_store.get_statistics()
+        total_docs = stats.get('total_documents', 0)
+        unique_files = stats.get('unique_files', 'Unknown')
+        model = stats.get('embedding_model', 'Unknown')
+        device = stats.get('device', 'Unknown')
+
+        stats_text = [
+            f"Total documents: {total_docs}",
+            f"Unique files: {unique_files}",
+            f"Embedding model: {model}",
+            f"Device: {device}"
+        ]
+
+        return "\n".join(stats_text)
+    except Exception as e:
+        logger.error(f"Error getting statistics: {e}")
+        return "Error getting database statistics"
 def main():
+    """Main function to run the RAG application"""
+    # Path for configuration file (defined before use below)
+    CONFIG_FILE_PATH = "rag_config.json"
+
+    print(f"Starting Document RAG System v{VERSION}")
+    print(f"Log file: {log_file}")
+
+    # Try to load configuration from file, or use defaults
+    if os.path.exists(CONFIG_FILE_PATH):
+        config = Config.from_file(CONFIG_FILE_PATH)
+    else:
+        config = Config(
+            local_dir="./chroma_db",  # Store Chroma files in a dedicated directory
+            collection_name="markdown_docs"
+        )
+        # Save default configuration
+        config.save_to_file(CONFIG_FILE_PATH)
 
     try:
         # Initialize vector store manager with existing collection
         vector_store = VectorStoreManager(config)
 
         # Initialize RAG system without API keys initially
+        rag_system = RAGSystem(vector_store, config)
 
         # Create the Gradio interface
+        with gr.Blocks(title=config.system_name) as app:
+            gr.Markdown(f"# {config.system_name} v{VERSION}")
+            gr.Markdown("Retrieve and generate answers from your documents using AI")
 
             with gr.Row():
                 with gr.Column(scale=1):
                     # API Keys and model selection
+                    with gr.Box():
+                        gr.Markdown("### LLM Configuration")
+                        model_choice = gr.Radio(
+                            choices=["openai", "gemini"],
+                            value="openai",
+                            label="Choose LLM Provider",
+                            info=f"Select which model to use ({config.openai_model} or {config.gemini_model})"
+                        )
+
+                        api_key_input = gr.Textbox(
+                            label="API Key",
+                            placeholder="Enter your API key here...",
+                            type="password",
+                            info="Your API key is not stored between sessions"
+                        )
+
+                        save_key_button = gr.Button("Save API Key", variant="primary")
+                        api_status = gr.Markdown("")
 
                     # Search controls
+                    with gr.Box():
+                        gr.Markdown("### Search Settings")
+                        num_results = gr.Slider(
+                            minimum=1,
+                            maximum=20,
+                            value=15,
+                            step=1,
+                            label="Number of documents to retrieve",
+                            info="Higher values may provide more context but slower responses"
+                        )
+
+                        temperature_slider = gr.Slider(
+                            minimum=0.0,
+                            maximum=1.0,
+                            value=config.temperature,
+                            step=0.05,
+                            label="Temperature",
+                            info="Lower values = more factual, higher values = more creative"
+                        )
+
+                        max_tokens_slider = gr.Slider(
+                            minimum=100,
+                            maximum=4000,
+                            value=config.max_tokens,
+                            step=100,
+                            label="Max Output Tokens",
+                            info="Maximum length of generated response"
+                        )