jzou19950715 committed
Commit dcf7268 · verified · 1 Parent(s): ad3151d

Update app.py

Files changed (1):
  1. app.py +788 -240
app.py CHANGED
@@ -1,272 +1,820 @@
-def main():
-    """Main function to run the RAG application"""
-    # Path for configuration file
-    CONFIG_FILE_PATH = "rag_config.json"
-
-    try:
-        # Try to load configuration from file, or use defaults
-        if os.path.exists(CONFIG_FILE_PATH):
-            config = Config.from_file(CONFIG_FILE_PATH)
-        else:
-            config = Config(
-                local_dir="./chroma_db",  # Store Chroma files in dedicated directory
-                collection_name="markdown_docs"
-            )
-            # Save default configuration
-            config.save_to_file(CONFIG_FILE_PATH)
-
-        print(f"Starting Document Knowledge Assistant v{VERSION}")
-        print(f"Log file: {log_file}")
-
-        # Initialize vector store manager with existing collection
-        vector_store = VectorStoreManager(config)
-
-        # Initialize RAG system without API keys initially
-        rag_system = RAGSystem(vector_store, config)
-
-        # Create the Gradio interface with custom CSS
-        with gr.Blocks(title="Document Knowledge Assistant", css=custom_css) as app:
-            gr.Markdown(f"# Document Knowledge Assistant v{VERSION}")
-            gr.Markdown("Ask questions about your documents and get comprehensive AI-powered answers")
-
-            # Main layout
-            with gr.Row():
-                # Left column for asking questions
-                with gr.Column(scale=3):
-                    with gr.Box():
-                        gr.Markdown("### Ask Your Question")
-                        query_input = gr.Textbox(
-                            label="",
-                            placeholder="What would you like to know about your documents?",
-                            lines=3
-                        )
-
-                        with gr.Row():
-                            query_button = gr.Button("Ask Question", variant="primary", scale=3)
-                            clear_button = gr.Button("Clear", variant="secondary", scale=1)
-
-                    with gr.Box():
-                        gr.Markdown("### Answer")
-                        response_output = gr.Markdown()
-
-                # Right column for settings
-                with gr.Column(scale=1):
-                    # API Keys and model selection
-                    with gr.Accordion("AI Model Settings", open=True):
-                        gr.Markdown("### AI Configuration")
-                        model_choice = gr.Radio(
-                            choices=["openai", "gemini"],
-                            value="openai",
-                            label="AI Provider",
-                            info=f"Select your preferred AI model"
-                        )
-
-                        api_key_input = gr.Textbox(
-                            label="API Key",
-                            placeholder="Enter your API key here...",
-                            type="password",
-                            info="Your key is not stored between sessions"
-                        )
-
-                        save_key_button = gr.Button("Save API Key", variant="primary")
-                        api_status = gr.Markdown("")
-
-                    # Advanced search controls
-                    with gr.Accordion("Advanced Settings", open=False):
-                        gr.Markdown("### Search & Response Settings")
-                        num_results = gr.Slider(
-                            minimum=3,
-                            maximum=15,
-                            value=config.default_top_k,
-                            step=1,
-                            label="Documents to search",
-                            info="Higher values provide more context"
-                        )
-
-                        temperature_slider = gr.Slider(
-                            minimum=0.0,
-                            maximum=1.0,
-                            value=config.temperature,
-                            step=0.05,
-                            label="Creativity",
-                            info="Lower = more factual, Higher = more creative"
-                        )
-
-                        max_tokens_slider = gr.Slider(
-                            minimum=500,
-                            maximum=4000,
-                            value=config.max_tokens,
-                            step=100,
-                            label="Response Length",
-                            info="Maximum words in response"
-                        )
-
-                    # Database stats - simplified
-                    with gr.Accordion("System Info", open=False):
-                        stats_display = gr.Markdown(get_db_stats(vector_store))
-
-                        gr.Markdown(f"""
-                        **System Details:**
-                        - Version: {VERSION}
-                        - Embedding: {vector_store.embedding_engine.model_name}
-                        - Device: {vector_store.embedding_engine.device}
-                        """)
-                        refresh_button = gr.Button("Refresh", variant="secondary", size="sm")
-
-            # Hidden element for search results (not visible to user)
-            with gr.Accordion("Debug Information", open=False, visible=False):
-                search_output = gr.Markdown()
-
-            # Query history at the bottom (optional section)
-            with gr.Accordion("Recent Questions", open=False):
-                history_list = gr.Dataframe(
-                    headers=["Time", "Question", "Model"],
-                    datatype=["str", "str", "str"],
-                    row_count=5,
-                    col_count=(3, "fixed"),
-                    interactive=False
-                )
-
-            # Footer
-            gr.Markdown(
-                """<div class="footer">Document Knowledge Assistant helps you get insights from your documents using AI.
-                Powered by Retrieval Augmented Generation.</div>"""
-            )
-
-            # Query history storage
-            query_history = []
-
-            # Function to update API key based on selected model
-            def update_api_key(api_key, model):
-                if not api_key.strip():
-                    return "❌ API key cannot be empty"
-
-                if model == "openai":
-                    success = rag_system.setup_openai(api_key)
-                    model_name = f"OpenAI {config.openai_model}"
-                else:
-                    success = rag_system.setup_gemini(api_key)
-                    model_name = f"Google {config.gemini_model}"
-
-                if success:
-                    return f"✅ {model_name} connected successfully"
-                else:
-                    return f"❌ Connection failed. Please check your API key and try again."
-
-            # Query function that returns both response and search results
-            def query_and_search(query, n_results, model, temperature, max_tokens):
-                # Update configuration with current UI values
-                config.temperature = float(temperature)
-                config.max_tokens = int(max_tokens)
-
-                start_time = datetime.now()
-
-                if not query.strip():
-                    return "Please enter a question to get an answer.", "", query_history[-5:] if query_history else []
-
-                try:
-                    # Verify that API keys are configured
-                    if (model == "openai" and rag_system.openai_client is None) or \
-                       (model == "gemini" and not rag_system.gemini_configured):
-                        return "Please configure your API key first. Enter your API key in the settings panel and click 'Save API Key'.", "", query_history[-5:] if query_history else []
-
-                    # Call the RAG system's query and generate function
-                    response, search_output_text = rag_system.query_and_generate(
-                        query=query,
-                        n_results=int(n_results),
-                        model=model
-                    )
-
-                    # Add to history
-                    timestamp = datetime.now().strftime("%H:%M")
-                    query_history.append([timestamp, query, model])
-
-                    # Keep only the last 100 queries
-                    if len(query_history) > 100:
-                        query_history.pop(0)
-
-                    # Update the history display with the most recent entries (reverse chronological)
-                    recent_history = list(reversed(query_history[-5:])) if len(query_history) >= 5 else list(reversed(query_history))
-
-                    # Calculate elapsed time
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-
-                    # Add subtle timing information to the response
-                    response_with_timing = f"{response}\n\n<small>Answered in {elapsed_time:.1f}s</small>"
-
-                    return response_with_timing, search_output_text, recent_history
-
-                except Exception as e:
-                    error_msg = f"Error processing query: {str(e)}"
-                    logger.error(error_msg)
-                    logger.error(traceback.format_exc())
-                    return "I encountered an error while processing your question. Please try again or check your API key settings.", "", query_history[-5:] if query_history else []
-
-            # Function to clear the input and results
-            def clear_inputs():
-                return "", "", "", query_history[-5:] if query_history else []
-
-            # Set up events
-            save_key_button.click(
-                fn=update_api_key,
-                inputs=[api_key_input, model_choice],
-                outputs=api_status
-            )
-
-            query_button.click(
-                fn=query_and_search,
-                inputs=[query_input, num_results, model_choice, temperature_slider, max_tokens_slider],
-                outputs=[response_output, search_output, history_list]
-            )
-
-            refresh_button.click(
-                fn=lambda: get_db_stats(vector_store),
-                inputs=None,
-                outputs=stats_display
-            )
-
-            clear_button.click(
-                fn=clear_inputs,
-                inputs=None,
-                outputs=[query_input, response_output, search_output, history_list]
-            )
-
-            # Handle Enter key in query input
-            query_input.submit(
-                fn=query_and_search,
-                inputs=[query_input, num_results, model_choice, temperature_slider, max_tokens_slider],
-                outputs=[response_output, search_output, history_list]
-            )
-
-            # Auto-fill examples
-            examples = [
-                ["What are the main features of this application?"],
-                ["How does the retrieval augmented generation work?"],
-                ["Can you explain the embedding models used in this system?"],
-            ]
-
-            gr.Examples(
-                examples=examples,
-                inputs=query_input,
-                outputs=[response_output, search_output, history_list],
-                fn=lambda q: query_and_search(q, num_results.value, model_choice.value, temperature_slider.value, max_tokens_slider.value),
-                cache_examples=False,
-            )
-
-        # Launch the interface with a nice theme
-        app.launch(
-            share=False,  # Set to True to create a public link
-            server_name="0.0.0.0",  # Listen on all interfaces
-            server_port=7860,  # Default Gradio port
-            debug=False,  # Set to True during development
-            auth=None,  # Add (username, password) tuple for basic auth
-            favicon_path="favicon.ico" if os.path.exists("favicon.ico") else None,
-            show_error=True
-        )
-
-    except Exception as e:
-        logger.critical(f"Error starting application: {e}")
-        print(f"Error starting application: {e}")
-        sys.exit(1)
+import os
+import sys
+import logging
+from pathlib import Path
+import json
+from datetime import datetime
+from typing import List, Dict, Any, Optional, Tuple, Union
+import traceback
+
+# Configure detailed logging with file output
+LOG_DIR = "logs"
+os.makedirs(LOG_DIR, exist_ok=True)
+log_file = os.path.join(LOG_DIR, f"rag_system_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
+
+# Set up root logger with both file and console handlers
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler(log_file),
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+logger = logging.getLogger("rag_system")
+logger.info(f"Starting RAG system. Log file: {log_file}")
+
+# Importing necessary libraries with error handling
+try:
+    import torch
+    import numpy as np
+    from sentence_transformers import SentenceTransformer
+    import chromadb
+    from chromadb.utils import embedding_functions
+    import gradio as gr
+    from openai import OpenAI
+    import google.generativeai as genai
+    logger.info("All required libraries successfully imported")
+except ImportError as e:
+    logger.critical(f"Failed to import required libraries: {e}")
+    print("ERROR: Missing required libraries. Please install with: pip install -r requirements.txt")
+    print(f"Specific error: {e}")
+    sys.exit(1)
+
+# Version info for tracking
+VERSION = "1.1.0"
+logger.info(f"RAG System Version: {VERSION}")
+
+# Custom CSS for better UI
+custom_css = """
+.gradio-container {
+    max-width: 1200px;
+    margin: auto;
+}
+.gr-prose h1 {
+    font-size: 2.5rem;
+    margin-bottom: 1rem;
+    color: #1a5276;
+}
+.gr-prose h3 {
+    font-size: 1.25rem;
+    font-weight: 600;
+    margin-top: 1rem;
+    margin-bottom: 0.5rem;
+    color: #2874a6;
+}
+.container {
+    margin: 0 auto;
+    padding: 2rem;
+}
+.gr-box {
+    border-radius: 8px;
+    box-shadow: 0 1px 3px rgba(0,0,0,0.12), 0 1px 2px rgba(0,0,0,0.24);
+    padding: 1rem;
+    margin-bottom: 1rem;
+    background-color: #f9f9f9;
+}
+.footer {
+    text-align: center;
+    font-size: 0.8rem;
+    color: #666;
+    margin-top: 2rem;
+}
+"""
+
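With the handler setup above, every `logger.*` call is mirrored both to stdout and to the timestamped file under `logs/`. Given that format string, a record looks like this (timestamp illustrative):

    2025-03-01 12:00:00,123 - rag_system - INFO - RAG System Version: 1.1.0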
+class Config:
+    """
+    Configuration for vector store and RAG system.
+
+    This class centralizes all configuration parameters for the application,
+    making it easier to modify settings and ensure consistency.
+
+    Attributes:
+        local_dir (str): Directory for ChromaDB persistence
+        embedding_model (str): Name of the embedding model to use
+        collection_name (str): Name of the ChromaDB collection
+        default_top_k (int): Default number of results to return
+        openai_model (str): Default OpenAI model to use
+        gemini_model (str): Default Gemini model to use
+        temperature (float): Temperature setting for LLM generation
+        max_tokens (int): Maximum tokens for LLM response
+        system_name (str): Name of the system for UI
+        context_limit (int): Maximum characters to include in context
+    """
+
+    def __init__(self,
+                 local_dir: str = "./chroma_db",
+                 embedding_model: str = "all-MiniLM-L6-v2",
+                 collection_name: str = "markdown_docs",
+                 default_top_k: int = 8,  # Increased from 5 to 8 for more context
+                 openai_model: str = "gpt-4o-mini",
+                 gemini_model: str = "gemini-1.5-flash",
+                 temperature: float = 0.3,
+                 max_tokens: int = 2000,  # Increased from 1000 to 2000 for more comprehensive responses
+                 system_name: str = "Document Knowledge Assistant",
+                 context_limit: int = 16000):  # Increased context limit for more comprehensive context
+        self.local_dir = local_dir
+        self.embedding_model = embedding_model
+        self.collection_name = collection_name
+        self.default_top_k = default_top_k
+        self.openai_model = openai_model
+        self.gemini_model = gemini_model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.system_name = system_name
+        self.context_limit = context_limit
+
+        # Create local directory if it doesn't exist
+        os.makedirs(local_dir, exist_ok=True)
+
+        logger.info(f"Initialized configuration: {self.__dict__}")
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert configuration to dictionary for serialization"""
+        return self.__dict__
+
+    @classmethod
+    def from_file(cls, config_path: str) -> 'Config':
+        """Load configuration from JSON file"""
+        try:
+            with open(config_path, 'r') as f:
+                config_dict = json.load(f)
+            logger.info(f"Loaded configuration from {config_path}")
+            return cls(**config_dict)
+        except Exception as e:
+            logger.error(f"Failed to load configuration from {config_path}: {e}")
+            logger.info("Using default configuration")
+            return cls()
+
+    def save_to_file(self, config_path: str) -> bool:
+        """Save configuration to JSON file"""
+        try:
+            with open(config_path, 'w') as f:
+                json.dump(self.to_dict(), f, indent=2)
+            logger.info(f"Saved configuration to {config_path}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to save configuration to {config_path}: {e}")
+            return False
+
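Because `to_dict()` serializes `self.__dict__` and `__init__` accepts every attribute as a keyword argument, `save_to_file` and `from_file` round-trip cleanly. A minimal sketch, assuming the class above is in scope (the filename matches the `CONFIG_FILE_PATH` used by the removed `main()`):

```python
# Round-trip the configuration through JSON (a sketch, not part of the commit).
config = Config(local_dir="./chroma_db", collection_name="markdown_docs")
config.save_to_file("rag_config.json")  # writes every attribute via json.dump(..., indent=2)

restored = Config.from_file("rag_config.json")
assert restored.default_top_k == 8      # defaults survive the round trip
assert restored.max_tokens == 2000
```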
+class EmbeddingEngine:
+    """
+    Handle embeddings with a lightweight model.
+
+    This class manages the embedding model used to convert text to vector
+    representations for semantic search.
+
+    Attributes:
+        model (SentenceTransformer): The loaded embedding model
+        model_name (str): Name of the successfully loaded model
+        vector_size (int): Dimension of the embedding vectors
+        device (str): Device used for inference ('cuda' or 'cpu')
+    """
+
+    def __init__(self, model_name="all-MiniLM-L6-v2"):
+        """
+        Initialize the embedding engine with the specified model.
+
+        Args:
+            model_name (str): Name of the embedding model to load
+
+        Raises:
+            SystemExit: If no embedding model could be loaded
+        """
+        # Use GPU if available
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device for embeddings: {self.device}")
+
+        # Try multiple model options in order of preference
+        model_options = [
+            model_name,
+            "all-MiniLM-L6-v2",         # Good balance of speed and quality
+            "paraphrase-MiniLM-L3-v2",  # Faster but less accurate
+            "all-mpnet-base-v2"         # Higher quality but larger model
+        ]
+
+        self.model = None
+
+        # Try each model in order until one works
+        for model_option in model_options:
+            try:
+                logger.info(f"Attempting to load embedding model: {model_option}")
+                self.model = SentenceTransformer(model_option)
+
+                # Move model to device
+                self.model.to(self.device)
+
+                logger.info(f"Successfully loaded embedding model: {model_option}")
+                self.model_name = model_option
+                self.vector_size = self.model.get_sentence_embedding_dimension()
+                logger.info(f"Embedding vector size: {self.vector_size}")
+                break
+
+            except Exception as e:
+                logger.warning(f"Failed to load embedding model {model_option}: {str(e)}")
+
+        if self.model is None:
+            error_msg = "Failed to load any embedding model. Please check your internet connection or install models locally."
+            logger.critical(error_msg)
+            raise SystemExit(error_msg)
+
+    def embed(self, texts: List[str]) -> np.ndarray:
+        """
+        Generate embeddings for a list of texts.
+
+        Args:
+            texts (List[str]): List of texts to embed
+
+        Returns:
+            np.ndarray: Array of embeddings
+
+        Raises:
+            ValueError: If the input is invalid
+            RuntimeError: If embedding fails
+        """
+        if not texts:
+            raise ValueError("Cannot embed empty list of texts")
+
+        try:
+            embeddings = self.model.encode(texts, convert_to_numpy=True)
+            return embeddings
+        except Exception as e:
+            logger.error(f"Error generating embeddings: {e}")
+            raise RuntimeError(f"Failed to generate embeddings: {e}")
+
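A quick sketch of how `embed()` might be exercised on its own, assuming the class above is in scope; the similarity step is purely illustrative, since search in this app goes through ChromaDB below:

```python
import numpy as np

engine = EmbeddingEngine()                   # falls back through model_options if needed
vecs = engine.embed(["markdown documents", "semantic search"])
print(vecs.shape)                            # (2, engine.vector_size), e.g. (2, 384) for MiniLM

# Cosine similarity between the two texts
sim = np.dot(vecs[0], vecs[1]) / (np.linalg.norm(vecs[0]) * np.linalg.norm(vecs[1]))
print(f"similarity: {sim:.3f}")
```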
+class VectorStoreManager:
+    """
+    Manage Chroma vector store operations - upload, query, etc.
+
+    This class provides an interface to the ChromaDB vector database,
+    handling document storage, retrieval, and management.
+
+    Attributes:
+        config (Config): Configuration parameters
+        client (chromadb.PersistentClient): ChromaDB client
+        collection (chromadb.Collection): The active ChromaDB collection
+        embedding_engine (EmbeddingEngine): Engine for generating embeddings
+    """
+
+    def __init__(self, config: Config):
+        """
+        Initialize the vector store manager.
+
+        Args:
+            config (Config): Configuration parameters
+
+        Raises:
+            SystemExit: If the vector store cannot be initialized
+        """
+        self.config = config
+
+        # Initialize Chroma client (local persistence)
+        logger.info(f"Initializing Chroma at {config.local_dir}")
+        try:
+            self.client = chromadb.PersistentClient(path=config.local_dir)
+            logger.info("ChromaDB client initialized successfully")
+        except Exception as e:
+            error_msg = f"Failed to initialize ChromaDB client: {e}"
+            logger.critical(error_msg)
+            raise SystemExit(error_msg)
+
+        # Get or create collection
+        try:
+            # Initialize embedding model
+            logger.info("Loading embedding model...")
+            self.embedding_engine = EmbeddingEngine(config.embedding_model)
+            logger.info(f"Using embedding model: {self.embedding_engine.model_name}")
+
+            # Create embedding function
+            sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+                model_name=self.embedding_engine.model_name
+            )
+
+            # Try to get existing collection or create a new one
+            try:
+                self.collection = self.client.get_collection(
+                    name=config.collection_name,
+                    embedding_function=sentence_transformer_ef
+                )
+                logger.info(f"Using existing collection: {config.collection_name}")
+            except Exception as e:
+                logger.warning(f"Error getting collection: {e}")
+                # Attempt to get a list of available collections
+                collections = self.client.list_collections()
+                if collections:
+                    logger.info(f"Available collections: {[c.name for c in collections]}")
+                    # Use the first available collection if any
+                    self.collection = self.client.get_collection(
+                        name=collections[0].name,
+                        embedding_function=sentence_transformer_ef
+                    )
+                    logger.info(f"Using collection: {collections[0].name}")
+                else:
+                    # Create new collection if none exist
+                    self.collection = self.client.create_collection(
+                        name=config.collection_name,
+                        embedding_function=sentence_transformer_ef,
+                        metadata={"hnsw:space": "cosine"}
+                    )
+                    logger.info(f"Created new collection: {config.collection_name}")
+
+        except Exception as e:
+            error_msg = f"Error initializing Chroma collection: {e}"
+            logger.critical(error_msg)
+            raise SystemExit(error_msg)
+
+    def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
+        """
+        Query the vector store with a text query.
+
+        Args:
+            query_text (str): The query text
+            n_results (int): Number of results to return
+
+        Returns:
+            List[Dict]: List of results with document text, metadata, and similarity score
+        """
+        if not query_text.strip():
+            logger.warning("Empty query received")
+            return []
+
+        try:
+            logger.info(f"Querying vector store with: '{query_text[:50]}...' (top {n_results})")
+
+            # Query the collection
+            search_results = self.collection.query(
+                query_texts=[query_text],
+                n_results=n_results,
+                include=["documents", "metadatas", "distances"]
+            )
+
+            # Format results
+            results = []
+            if search_results["documents"] and len(search_results["documents"][0]) > 0:
+                for i in range(len(search_results["documents"][0])):
+                    results.append({
+                        'document': search_results["documents"][0][i],
+                        'metadata': search_results["metadatas"][0][i] if search_results["metadatas"] else {},
+                        'score': 1.0 - search_results["distances"][0][i],  # Convert distance to similarity
+                        'distance': search_results["distances"][0][i]
+                    })
+                logger.info(f"Found {len(results)} results for query")
+            else:
+                logger.info("No results found for query")
+
+            return results
+        except Exception as e:
+            logger.error(f"Error querying collection: {e}")
+            logger.debug(traceback.format_exc())
+            return []
+
+    def add_document(self,
+                     document: str,
+                     doc_id: str,
+                     metadata: Dict[str, Any]) -> bool:
+        """
+        Add a document to the vector store.
+
+        Args:
+            document (str): The document text
+            doc_id (str): Unique identifier for the document
+            metadata (Dict[str, Any]): Metadata about the document
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            logger.info(f"Adding document '{doc_id}' to vector store")
+
+            # Add the document to the collection
+            self.collection.add(
+                documents=[document],
+                ids=[doc_id],
+                metadatas=[metadata]
+            )
+
+            logger.info(f"Successfully added document '{doc_id}'")
+            return True
+        except Exception as e:
+            logger.error(f"Error adding document to collection: {e}")
+            return False
+
+    def delete_document(self, doc_id: str) -> bool:
+        """
+        Delete a document from the vector store.
+
+        Args:
+            doc_id (str): ID of the document to delete
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            logger.info(f"Deleting document '{doc_id}' from vector store")
+            self.collection.delete(ids=[doc_id])
+            logger.info(f"Successfully deleted document '{doc_id}'")
+            return True
+        except Exception as e:
+            logger.error(f"Error deleting document from collection: {e}")
+            return False
+
+    def get_statistics(self) -> Dict[str, Any]:
+        """
+        Get statistics about the vector store.
+
+        Returns:
+            Dict[str, Any]: Statistics about the vector store
+        """
+        stats = {
+            'collection_name': self.config.collection_name,
+            'embedding_model': self.embedding_engine.model_name,
+            'embedding_dimensions': self.embedding_engine.vector_size,
+            'device': self.embedding_engine.device
+        }
+
+        try:
+            # Get collection count
+            collection_count = self.collection.count()
+            stats['total_documents'] = collection_count
+
+            # Get unique metadata values
+            if collection_count > 0:
+                try:
+                    # Get a sample of document metadata
+                    sample_results = self.collection.get(limit=min(collection_count, 100))
+                    if sample_results and 'metadatas' in sample_results and sample_results['metadatas']:
+                        # Count unique files if filename exists in metadata
+                        filenames = set()
+                        for metadata in sample_results['metadatas']:
+                            if 'filename' in metadata:
+                                filenames.add(metadata['filename'])
+                        stats['unique_files'] = len(filenames)
+                except Exception as e:
+                    logger.warning(f"Error getting metadata statistics: {e}")
+
+            logger.info(f"Vector store statistics: {stats}")
+        except Exception as e:
+            logger.error(f"Error getting statistics: {e}")
+            stats['error'] = str(e)
+
+        return stats
+
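A minimal sketch of the store's add/query cycle, assuming `Config` and `VectorStoreManager` above are in scope; the document text, ID, and metadata are hypothetical:

```python
store = VectorStoreManager(Config())

# Index one document (ID and metadata are illustrative)
store.add_document(
    document="ChromaDB persists this collection under ./chroma_db.",
    doc_id="doc-001",
    metadata={"filename": "notes.md", "title": "Persistence notes"},
)

# Retrieve: each hit carries the text, metadata, and a distance-derived score
for hit in store.query("Where is the collection persisted?", n_results=3):
    print(f"{hit['score']:.2f}  {hit['metadata'].get('title', '?')}  {hit['document'][:60]}")
```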
+class RAGSystem:
+    """
+    Retrieval-Augmented Generation with multiple LLM providers.
+
+    This class handles the RAG workflow: retrieval of relevant documents,
+    formatting context, and generating responses with different LLM providers.
+
+    Attributes:
+        vector_store (VectorStoreManager): Manager for vector store operations
+        openai_client (Optional[OpenAI]): OpenAI client
+        gemini_configured (bool): Whether Gemini API is configured
+        config (Config): Configuration parameters
+    """
+
+    def __init__(self, vector_store: VectorStoreManager, config: Config):
+        """
+        Initialize the RAG system.
+
+        Args:
+            vector_store (VectorStoreManager): Vector store manager
+            config (Config): Configuration parameters
+        """
+        self.vector_store = vector_store
+        self.config = config
+        self.openai_client = None
+        self.gemini_configured = False
+
+        logger.info("Initialized RAG system")
+
+    def setup_openai(self, api_key: str) -> bool:
+        """
+        Set up OpenAI client with API key.
+
+        Args:
+            api_key (str): OpenAI API key
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        if not api_key.strip():
+            logger.warning("Empty OpenAI API key provided")
+            return False
+
+        try:
+            logger.info("Setting up OpenAI client")
+            self.openai_client = OpenAI(api_key=api_key)
+            # Test the API key with a simple request
+            response = self.openai_client.chat.completions.create(
+                model=self.config.openai_model,
+                messages=[
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "Test connection"}
+                ],
+                max_tokens=10
+            )
+            logger.info("OpenAI client configured successfully")
+            return True
+        except Exception as e:
+            logger.error(f"Error initializing OpenAI client: {e}")
+            self.openai_client = None
+            return False
+
+    def setup_gemini(self, api_key: str) -> bool:
+        """
+        Set up Gemini with API key.
+
+        Args:
+            api_key (str): Google AI API key
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        if not api_key.strip():
+            logger.warning("Empty Gemini API key provided")
+            return False
+
+        try:
+            logger.info("Setting up Gemini client")
+            genai.configure(api_key=api_key)
+
+            # Test the API key with a simple request
+            model = genai.GenerativeModel(self.config.gemini_model)
+            response = model.generate_content("Test connection")
+
+            self.gemini_configured = True
+            logger.info("Gemini client configured successfully")
+            return True
+        except Exception as e:
+            logger.error(f"Error configuring Gemini: {e}")
+            self.gemini_configured = False
+            return False
+
+    def format_context(self, documents: List[Dict]) -> str:
+        """
+        Format retrieved documents into context for the LLM.
+
+        Args:
+            documents (List[Dict]): List of retrieved documents
+
+        Returns:
+            str: Formatted context for the LLM
+        """
+        if not documents:
+            logger.warning("No documents provided for context formatting")
+            return "No relevant documents found."
+
+        logger.info(f"Formatting {len(documents)} documents for context")
+        context_parts = []
+
+        for i, doc in enumerate(documents):
+            metadata = doc['metadata']
+            # Extract document metadata in a robust way
+            title = metadata.get('title', metadata.get('filename', 'Unknown document'))
+
+            # Format header with just essential metadata for cleaner context
+            header = f"Document {i+1} - {title}"
+
+            # For readability, limit length of context document (using config value)
+            doc_text = doc['document']
+            if len(doc_text) > (self.config.context_limit // len(documents)):
+                # Divide context limit among the documents
+                max_length = self.config.context_limit // len(documents)
+                doc_text = doc_text[:max_length] + "... [Document truncated for brevity]"
+
+            context_parts.append(f"{header}:\n{doc_text}\n")
+
+        full_context = "\n".join(context_parts)
+        logger.info(f"Created context with {len(full_context)} characters")
+
+        return full_context
+
+    def generate_response_openai(self, query: str, context: str) -> str:
+        """
+        Generate a response using OpenAI model with context.
+
+        Args:
+            query (str): User query
+            context (str): Formatted document context
+
+        Returns:
+            str: Generated response
+        """
+        if not self.openai_client:
+            logger.warning("OpenAI API key not configured for response generation")
+            return "Please configure an OpenAI API key to use this feature. Enter your API key in the field and click 'Save API Key'."
+
+        # Improved system prompt for better, more comprehensive responses
+        system_prompt = """
+        You are an exceptionally helpful, clear, and friendly AI research assistant. Your goal is to provide comprehensive, well-structured, and insightful answers based on the provided document context.
+
+        Guidelines for your response:
+
+        1. USE ONLY the information contained in the provided context documents to form your answer. If the context doesn't contain enough information to provide a complete answer, acknowledge this limitation clearly.
+
+        2. Always provide well-structured, detailed responses between 300-500 words that thoroughly address the user's question.
+
+        3. Format your response with clear headings, bullet points, or numbered lists when appropriate to enhance readability.
+
+        4. Cite your sources by referring to the document numbers (e.g., "According to Document 1...") to support your claims.
+
+        5. Use a friendly, conversational, and supportive tone that makes complex information accessible.
+
+        6. If different documents offer conflicting information, acknowledge these differences and present both perspectives without bias.
+
+        7. When appropriate, organize information into logical categories or chronological order to improve clarity.
+
+        8. Use examples from the documents to illustrate key points when available.
+
+        9. Conclude with a brief summary of the main points if the answer is complex.
+
+        10. Remember to stay focused on the user's specific question while providing sufficient context for complete understanding.
+        """
+
+        try:
+            logger.info(f"Generating response with OpenAI ({self.config.openai_model})")
+
+            start_time = datetime.now()
+            response = self.openai_client.chat.completions.create(
+                model=self.config.openai_model,
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
+                ],
+                temperature=self.config.temperature,
+                max_tokens=self.config.max_tokens,
+            )
+
+            generation_time = (datetime.now() - start_time).total_seconds()
+            response_text = response.choices[0].message.content
+
+            logger.info(f"Generated response with OpenAI in {generation_time:.2f} seconds")
+            return response_text
+        except Exception as e:
+            error_msg = f"Error generating response with OpenAI: {str(e)}"
+            logger.error(error_msg)
+            return f"I encountered an error while generating your response. Please try again or check your API key. Error details: {str(e)}"
+
+    def generate_response_gemini(self, query: str, context: str) -> str:
+        """
+        Generate a response using Gemini with context.
+
+        Args:
+            query (str): User query
+            context (str): Formatted document context
+
+        Returns:
+            str: Generated response
+        """
+        if not self.gemini_configured:
+            logger.warning("Gemini API key not configured for response generation")
+            return "Please configure a Google AI API key to use this feature. Enter your API key in the field and click 'Save API Key'."
+
+        # Improved Gemini prompt for more comprehensive and user-friendly responses
+        prompt = f"""
+        You are a knowledgeable and friendly research assistant who excels at providing clear, comprehensive, and well-structured responses. Your goal is to help users understand complex information from documents in an accessible way.
+
+        **Guidelines for Your Response:**
+
+        - Create a detailed, well-organized response of approximately 300-500 words that thoroughly addresses the user's question.
+        - Use ONLY information from the provided context documents.
+        - Structure your answer with clear paragraphs, and use headings, bullet points, or numbered lists when appropriate.
+        - Maintain a friendly, conversational tone that makes information accessible and engaging.
+        - When citing information, reference specific documents by number (e.g., "As mentioned in Document 2...").
+        - If the context doesn't contain enough information for a complete answer, acknowledge this limitation while providing what you can from the available context.
+        - If documents contain conflicting information, present both perspectives fairly.
+        - Conclude with a brief summary if the topic is complex.
+
+        **Context Documents:**
+        {context}
+
+        **User's Question:**
+        {query}
+
+        **Your Response:**
+        """
+
+        try:
+            logger.info(f"Generating response with Gemini ({self.config.gemini_model})")
+
+            start_time = datetime.now()
+            model = genai.GenerativeModel(self.config.gemini_model)
+
+            generation_config = {
+                "temperature": self.config.temperature,
+                "max_output_tokens": self.config.max_tokens,
+                "top_p": 0.9,
+                "top_k": 40
+            }
+
+            response = model.generate_content(
+                prompt,
+                generation_config=generation_config
+            )
+
+            generation_time = (datetime.now() - start_time).total_seconds()
+            response_text = response.text
+
+            logger.info(f"Generated response with Gemini in {generation_time:.2f} seconds")
+            return response_text
+        except Exception as e:
+            error_msg = f"Error generating response with Gemini: {str(e)}"
+            logger.error(error_msg)
+            return f"I encountered an error while generating your response. Please try again or check your API key. Error details: {str(e)}"
+
+    def query_and_generate(self,
+                           query: str,
+                           n_results: int = 5,
+                           model: str = "openai") -> Tuple[str, str]:
+        """
+        Retrieve relevant documents and generate a response using the specified model.
+
+        Args:
+            query (str): User query
+            n_results (int): Number of documents to retrieve
+            model (str): Model provider to use ('openai' or 'gemini')
+
+        Returns:
+            Tuple[str, str]: (Generated response, Search results)
+        """
+        if not query.strip():
+            logger.warning("Empty query received")
+            return "Please enter a question to get a response.", "No search performed."
+
+        logger.info(f"Processing query: '{query[:50]}...' with {model} model")
+
+        # Query vector store
+        documents = self.vector_store.query(query, n_results=n_results)
+
+        # Format search results (for logs and hidden UI component)
+        # We'll format this in a way that's more useful for reference but not shown in UI
+        formatted_results = []
+        for i, res in enumerate(documents):
+            metadata = res['metadata']
+            title = metadata.get('title', metadata.get('filename', 'Unknown'))
+            score = res['score']
+
+            # Only include a very brief preview for reference
+            preview = res['document'][:100] + '...' if len(res['document']) > 100 else res['document']
+            formatted_results.append(f"Document {i+1}: {title} (Relevance: {score:.2f})")
+
+        search_output_text = "\n".join(formatted_results) if formatted_results else "No relevant documents found."
+
+        if not documents:
+            logger.warning("No relevant documents found")
+            return "I couldn't find relevant information in the knowledge base to answer your question. Could you try rephrasing your question or ask about a different topic?", search_output_text
+
+        # Format context
+        context = self.format_context(documents)
+
+        # Generate response with the appropriate model
+        if model == "openai":
+            response = self.generate_response_openai(query, context)
+        elif model == "gemini":
+            response = self.generate_response_gemini(query, context)
+        else:
+            error_msg = f"Unknown model: {model}"
+            logger.error(error_msg)
+            return error_msg, search_output_text
+
+        return response, search_output_text
+
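Putting the pieces together, an end-to-end sketch of the retrieval-plus-generation flow, assuming the classes above are in scope and a valid key is supplied (the key placeholder is not real):

```python
config = Config.from_file("rag_config.json")
store = VectorStoreManager(config)
rag = RAGSystem(store, config)

# setup_openai() verifies the key with a tiny test completion before returning True
if rag.setup_openai("sk-...your-key..."):
    answer, sources = rag.query_and_generate(
        "How does retrieval augmented generation work?",
        n_results=config.default_top_k,
        model="openai",                  # or "gemini" after setup_gemini()
    )
    print(answer)
    print(sources)                       # one "Document N: title (Relevance: x.xx)" line per hit
```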
+def get_db_stats(vector_store: VectorStoreManager) -> str:
+    """
+    Function to get vector store statistics.
+
+    Args:
+        vector_store (VectorStoreManager): Vector store manager
+
+    Returns:
+        str: Formatted statistics string
+    """
+    try:
+        stats = vector_store.get_statistics()
+        total_docs = stats.get('total_documents', 0)
+
+        stats_text = f"Documents in knowledge base: {total_docs}"
+        return stats_text
+    except Exception as e:
+        logger.error(f"Error getting statistics: {e}")
+        return "Error getting database statistics"
+
+# Helper function for loading documents (can be expanded in future versions)
+def load_document(file_path: str, chunk_size: int = 2000, chunk_overlap: int = 200) -> bool:
+    """
+    Load a document into the vector store.
+
+    Args:
+        file_path (str): Path to the document
+        chunk_size (int): Size of chunks to split the document into
+        chunk_overlap (int): Overlap between chunks
+
+    Returns:
+        bool: True if successful, False otherwise
+    """
+    try:
+        try:
+            logger.info(f"Loading document: {file_path}")
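The diff excerpt cuts off inside `load_document`, so its chunking logic is not shown here. A sliding-window splitter consistent with the `chunk_size`/`chunk_overlap` parameters could look like the following; this is an assumption, not the commit's implementation:

```python
from typing import List

def chunk_text(text: str, chunk_size: int = 2000, chunk_overlap: int = 200) -> List[str]:
    # Hypothetical helper, not part of the commit: fixed-size windows
    # that overlap by chunk_overlap characters.
    step = max(chunk_size - chunk_overlap, 1)
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]
```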