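"""Gradio front end for a local RAG system: retrieves documents from a persistent
Chroma vector store (collection "markdown_docs") and generates answers with either
OpenAI GPT-4o mini or Google Gemini 1.5 Flash."""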
import os
import sys
import logging
from pathlib import Path
import json
from datetime import datetime
from typing import List, Dict, Any, Optional

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Importing necessary libraries
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.utils import embedding_functions
import gradio as gr
from openai import OpenAI
import google.generativeai as genai
# Configuration class
class Config:
    """Configuration for vector store and RAG"""

    def __init__(self,
                 local_dir: str = ".",
                 embedding_model: str = "all-MiniLM-L6-v2",
                 collection_name: str = "markdown_docs"):
        self.local_dir = local_dir
        self.embedding_model = embedding_model
        self.collection_name = collection_name
# Embedding engine
class EmbeddingEngine:
    """Handle embeddings with a lightweight model"""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Try multiple model options in order of preference
        model_options = [
            model_name,
            "all-MiniLM-L6-v2",
            "paraphrase-MiniLM-L3-v2",
            "all-mpnet-base-v2"  # Higher quality but larger model
        ]
        self.model = None

        # Try each model in order until one works
        for model_option in model_options:
            try:
                logger.info(f"Attempting to load model: {model_option}")
                self.model = SentenceTransformer(model_option)

                # Move model to device
                self.model.to(self.device)

                logger.info(f"Successfully loaded model: {model_option}")
                self.model_name = model_option
                self.vector_size = self.model.get_sentence_embedding_dimension()
                break
            except Exception as e:
                logger.warning(f"Failed to load model {model_option}: {str(e)}")

        if self.model is None:
            logger.error("Failed to load any embedding model. Exiting.")
            sys.exit(1)
class VectorStoreManager:
    """Manage Chroma vector store operations - upload, query, etc."""

    def __init__(self, config: Config):
        self.config = config

        # Initialize Chroma client (local persistence)
        logger.info(f"Initializing Chroma at {config.local_dir}")
        self.client = chromadb.PersistentClient(path=config.local_dir)

        # Get or create collection
        try:
            # Initialize embedding model
            logger.info("Loading embedding model...")
            self.embedding_engine = EmbeddingEngine(config.embedding_model)
            logger.info(f"Using model: {self.embedding_engine.model_name}")

            # Create embedding function
            sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self.embedding_engine.model_name
            )

            # Try to get existing collection
            try:
                self.collection = self.client.get_collection(
                    name=config.collection_name,
                    embedding_function=sentence_transformer_ef
                )
                logger.info(f"Using existing collection: {config.collection_name}")
            except Exception as e:
                logger.error(f"Error getting collection: {e}")
                # Attempt to get a list of available collections
                collections = self.client.list_collections()
                if collections:
                    logger.info(f"Available collections: {[c.name for c in collections]}")
                    # Use the first available collection if any
                    self.collection = self.client.get_collection(
                        name=collections[0].name,
                        embedding_function=sentence_transformer_ef
                    )
                    logger.info(f"Using collection: {collections[0].name}")
                else:
                    # Create new collection if none exist
                    self.collection = self.client.create_collection(
                        name=config.collection_name,
                        embedding_function=sentence_transformer_ef,
                        metadata={"hnsw:space": "cosine"}
                    )
                    logger.info(f"Created new collection: {config.collection_name}")
        except Exception as e:
            logger.error(f"Error initializing Chroma collection: {e}")
            sys.exit(1)
    def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
        """
        Query the vector store with a text query
        """
        try:
            # Query the collection
            search_results = self.collection.query(
                query_texts=[query_text],
                n_results=n_results,
                include=["documents", "metadatas", "distances"]
            )

            # Format results
            results = []
            if search_results["documents"] and len(search_results["documents"][0]) > 0:
                for i in range(len(search_results["documents"][0])):
                    results.append({
                        'document': search_results["documents"][0][i],
                        'metadata': search_results["metadatas"][0][i],
                        'score': 1.0 - search_results["distances"][0][i]  # Convert distance to similarity
                    })
            return results
        except Exception as e:
            logger.error(f"Error querying collection: {e}")
            return []
    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        stats = {}
        try:
            # Get collection count
            collection_info = self.collection.count()
            stats['total_documents'] = collection_info

            # Estimate unique files - with no chunking, each document is a file
            stats['unique_files'] = collection_info
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            stats['error'] = str(e)
        return stats
class RAGSystem:
    """Retrieval-Augmented Generation with multiple LLM providers"""

    def __init__(self, vector_store: VectorStoreManager):
        self.vector_store = vector_store
        self.openai_client = None
        self.gemini_configured = False

    def setup_openai(self, api_key: str):
        """Set up OpenAI client with API key"""
        try:
            self.openai_client = OpenAI(api_key=api_key)
            return True
        except Exception as e:
            logger.error(f"Error initializing OpenAI client: {e}")
            return False

    def setup_gemini(self, api_key: str):
        """Set up Gemini with API key"""
        try:
            genai.configure(api_key=api_key)
            self.gemini_configured = True
            return True
        except Exception as e:
            logger.error(f"Error configuring Gemini: {e}")
            return False

    def format_context(self, documents: List[Dict]) -> str:
        """Format retrieved documents into context for the LLM"""
        if not documents:
            return "No relevant documents found."

        context_parts = []
        for i, doc in enumerate(documents):
            metadata = doc['metadata']
            title = metadata.get('title', metadata.get('filename', 'Unknown document'))

            # For readability, limit length of each context document
            doc_text = doc['document']
            if len(doc_text) > 10000:  # Limit long documents in context
                doc_text = doc_text[:10000] + "... [Document truncated for context]"

            context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
        return "\n".join(context_parts)
    def generate_response_openai(self, query: str, context: str) -> str:
        """Generate a response using an OpenAI model with context"""
        if not self.openai_client:
            return "Error: OpenAI API key not configured. Please enter an API key in the API key field."

        system_prompt = """
        You are a helpful assistant that answers questions based on the context provided.
        Use the information from the context to answer the user's question.
        If the context doesn't contain the information needed, say so clearly.
        Always cite the specific sections from the context that you used in your answer.
        """

        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-4o-mini",  # Use GPT-4o mini
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                ],
                temperature=0.3,  # Lower temperature for more factual responses
                max_tokens=5000,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error generating response with OpenAI: {e}")
            return f"Error generating response with OpenAI: {str(e)}"
    def generate_response_gemini(self, query: str, context: str) -> str:
        """Generate a response using Gemini with context"""
        if not self.gemini_configured:
            return "Error: Google AI API key not configured. Please enter an API key in the API key field."

        prompt = f"""
        <prompt>
            <system>
                <name>Loss Dog</name>
                <role>You are a highly intelligent AI specializing in labor market analysis, job trends, and skillset forecasting. You utilize a combination of structured data from sources like the Bureau of Labor Statistics (BLS) and the World Economic Forum (WEF), alongside advanced retrieval-augmented generation (RAG) techniques.</role>
                <goal>Your mission is to provide insightful, data-driven, and comprehensive answers to users seeking career and job market intelligence. You must ensure clarity, depth, and practical relevance in all responses.</goal>
                <personality>
                    <tone>Friendly, professional, and engaging</tone>
                    <depth>Detailed, nuanced, and well-explained</depth>
                    <clarity>Well-structured with headings, citations, and easy-to-follow breakdowns</clarity>
                </personality>
                <methodology>
                    <data_sources>
                        <source>Bureau of Labor Statistics (BLS)</source>
                        <source>World Economic Forum (WEF) reports</source>
                        <source>Market research studies</source>
                        <source>Industry whitepapers</source>
                        <source>Company hiring trends</source>
                    </data_sources>
                    <reasoning_strategy>
                        <if_data_available>
                            <response>
                                Use precise statistics, industry insights, and expert analyses from retrieved sources to craft an evidence-based answer.
                            </response>
                        </if_data_available>
                        <if_data_unavailable>
                            <response>
                                Clearly state that the exact data is unavailable. However, provide a **comprehensive explanation** using logical deduction, adjacent industry trends, historical patterns, and economic principles.
                            </response>
                        </if_data_unavailable>
                    </reasoning_strategy>
                    <output_expectations>
                        <length>100-500 words, depending on complexity and sources available</length>
                        <structure>
                            <section>Introduction (sets context and purpose)</section>
                            <section>Data-backed analysis (citing retrieved sources)</section>
                            <section>Logical deduction and reasoning (when necessary)</section>
                            <section>Conclusion (summarizes insights and provides actionable takeaways)</section>
                        </structure>
                        <citation_style>Clearly cite data sources within the response (e.g., "According to BLS 2024 report...").</citation_style>
                        <engagement>Encourage follow-up questions and deeper exploration where relevant.</engagement>
                    </output_expectations>
                </methodology>
            </system>
        </prompt>

        Context:
        {context}

        Question: {query}
        """

        try:
            model = genai.GenerativeModel('gemini-1.5-flash')
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            logger.error(f"Error generating response with Gemini: {e}")
            return f"Error generating response with Gemini: {str(e)}"
    def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
        """Retrieve relevant documents and generate a response using the specified model"""
        # Query vector store
        documents = self.vector_store.query(query, n_results=n_results)

        if not documents:
            return "No relevant documents found to answer your question."

        # Format context
        context = self.format_context(documents)

        # Generate response with the appropriate model
        if model == "openai":
            return self.generate_response_openai(query, context)
        elif model == "gemini":
            return self.generate_response_gemini(query, context)
        else:
            return f"Unknown model: {model}"
# Main function to run the application
def main():
    # Initialize the system with current directory as the Chroma location
    config = Config(
        local_dir=".",  # Look for Chroma files in current directory
        collection_name="markdown_docs"
    )

    try:
        # Initialize vector store manager with existing collection
        vector_store = VectorStoreManager(config)

        # Initialize RAG system without API keys initially
        rag_system = RAGSystem(vector_store)

        # Create the Gradio interface
        with gr.Blocks(title="Document RAG System") as app:
            gr.Markdown("# Document RAG System")

            with gr.Row():
                with gr.Column(scale=1):
                    # API keys and model selection
                    model_choice = gr.Radio(
                        choices=["openai", "gemini"],
                        value="openai",
                        label="Choose LLM Provider",
                        info="Select which model to use (GPT-4o mini or Gemini 1.5 Flash)"
                    )
                    api_key_input = gr.Textbox(
                        label="API Key",
                        placeholder="Enter your API key here...",
                        type="password"
                    )
                    save_key_button = gr.Button("Save API Key", variant="primary")
                    api_status = gr.Markdown("")

                    # Search controls
                    num_results = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=10,
                        step=1,
                        label="Number of documents to retrieve"
                    )

                    # Database stats
                    gr.Markdown("### Database Statistics")
                    stats_display = gr.Textbox(
                        label="",
                        value=get_db_stats(vector_store),
                        lines=2
                    )
                    refresh_button = gr.Button("Refresh Stats")

                with gr.Column(scale=2):
                    # Query and response
                    query_input = gr.Textbox(
                        label="Your Question",
                        placeholder="Ask a question about your documents...",
                        lines=2
                    )
                    query_button = gr.Button("Ask Question", variant="primary")

                    gr.Markdown("### Response")
                    response_output = gr.Markdown()

                    gr.Markdown("### Document Search Results")
                    search_output = gr.Markdown()

            # Function to update API key based on selected model
            def update_api_key(api_key, model):
                if model == "openai":
                    success = rag_system.setup_openai(api_key)
                    model_name = "OpenAI GPT-4o mini"
                else:
                    success = rag_system.setup_gemini(api_key)
                    model_name = "Google Gemini 1.5 Flash"

                if success:
                    return f"✅ {model_name} API key configured successfully"
                else:
                    return f"❌ Failed to configure {model_name} API key"

            # Query function that returns both response and search results
            def query_and_search(query, n_results, model):
                # Get search results first
                results = vector_store.query(query, n_results=int(n_results))

                # Format search results
                formatted_results = []
                for i, res in enumerate(results):
                    metadata = res['metadata']
                    title = metadata.get('title', metadata.get('filename', 'Unknown'))
                    preview = res['document'][:500] + '...' if len(res['document']) > 500 else res['document']
                    formatted_results.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n"
                                              f"**Source:** {title}\n"
                                              f"**Preview:**\n{preview}\n\n---\n")
                search_output_text = "\n".join(formatted_results) if formatted_results else "No results found."

                # Generate response if we have results
                response = "No documents found to answer your question."
                if results:
                    context = rag_system.format_context(results)
                    if model == "openai":
                        response = rag_system.generate_response_openai(query, context)
                    else:
                        response = rag_system.generate_response_gemini(query, context)

                return response, search_output_text

            # Set up events
            save_key_button.click(
                fn=update_api_key,
                inputs=[api_key_input, model_choice],
                outputs=api_status
            )
            query_button.click(
                fn=query_and_search,
                inputs=[query_input, num_results, model_choice],
                outputs=[response_output, search_output]
            )
            refresh_button.click(
                fn=lambda: get_db_stats(vector_store),
                inputs=None,
                outputs=stats_display
            )

        # Launch the interface
        app.launch()
    except Exception as e:
        logger.error(f"Error initializing application: {e}")
        print(f"Error: {e}")
        sys.exit(1)
# Helper function to get database stats
def get_db_stats(vector_store):
    """Function to get vector store statistics"""
    try:
        stats = vector_store.get_statistics()
        return f"Total documents: {stats.get('total_documents', 0)}"
    except Exception as e:
        logger.error(f"Error getting statistics: {e}")
        return "Error getting database statistics"


if __name__ == "__main__":
    main()
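# Example usage (a sketch; assumes this module is saved as app.py and a Chroma
# database containing the "markdown_docs" collection already exists in the
# working directory):
#
#   pip install torch numpy sentence-transformers chromadb gradio openai google-generativeai
#   python app.py
#
# Then open the local Gradio URL printed in the console, choose a provider,
# paste the matching OpenAI or Google AI API key, and ask questions about the
# indexed documents.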