| """ | |
| Groq Medical RAG System v2.0 | |
| FREE Groq Cloud API integration for advanced medical reasoning | |
| """ | |
| import os | |
| import time | |
| import logging | |
| import numpy as np | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from dataclasses import dataclass | |
| from dotenv import load_dotenv | |
| from pathlib import Path | |
| import argparse | |
| import shutil | |
| import re | |
| # Langchain for document loading and splitting | |
| from langchain_community.document_loaders import UnstructuredMarkdownLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| # Sentence Transformers for re-ranking | |
| from sentence_transformers import CrossEncoder | |
| # Groq API integration | |
| from groq import Groq | |
| from tenacity import retry, stop_after_attempt, wait_fixed, before_sleep_log | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| # Import our simplified components | |
| from .simple_vector_store import SimpleVectorStore, SearchResult | |
@dataclass
class MedicalResponse:
    """Enhanced medical response structure"""
    answer: str
    confidence: float
    sources: List[str]
    query_time: float
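
# Illustrative only: a MedicalResponse as query() would construct it,
# with made-up field values:
#
#   MedicalResponse(
#       answer="## Management\n- **Oxytocin 10 IU IM** ... [1]",
#       confidence=0.87,
#       sources=["maternal_sepsis_guideline"],
#       query_time=1.42,
#   )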
class GroqMedicalRAG:
    """Groq-powered Medical RAG System v2.0 - FREE LLM integration"""

    def __init__(self,
                 vector_store_dir: str = "simple_vector_store",
                 processed_docs_dir: str = "src/processed_markdown",
                 groq_api_key: Optional[str] = None):
        """Initialize the Groq medical RAG system"""
        # Get the absolute path to the project root directory
        project_root = Path(__file__).parent.parent.resolve()
        self.vector_store_dir = project_root / vector_store_dir
        self.processed_docs_dir = project_root / processed_docs_dir

        # Initialize Groq client
        self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            raise ValueError(
                "GROQ_API_KEY environment variable not set. "
                "Get your free API key from https://console.groq.com/keys"
            )
        self.groq_client = Groq(api_key=self.groq_api_key)
        self.model_name = "llama3-70b-8192"

        # Initialize Cross-Encoder for re-ranking
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # Initialize components
        self.vector_store = None
        self.setup_logging()
        self._initialize_system()
    def setup_logging(self):
        """Set up logging for the RAG system"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)
    def _initialize_system(self, force_recreate: bool = False):
        """Initialize the RAG system components"""
        try:
            # If forcing recreation, delete the old vector store
            if force_recreate and self.vector_store_dir.exists():
                self.logger.warning(f"Recreating index as requested. Deleting {self.vector_store_dir}...")
                shutil.rmtree(self.vector_store_dir)

            # Initialize vector store
            self.vector_store = SimpleVectorStore(vector_store_dir=self.vector_store_dir)

            # Try to load an existing vector store; build one if that fails
            if not self.vector_store.load_vector_store():
                self.logger.info("Creating new vector store from documents...")
                self._create_vector_store()
            else:
                self.logger.info("Loaded existing vector store")

            # Test Groq connection
            self._test_groq_connection()
            self.logger.info("Groq Medical RAG system initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing RAG system: {e}")
            raise
    # Retry transient failures; with reraise=True, tenacity re-raises the last
    # error after 3 attempts (this puts the tenacity imports above to use).
    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2),
           before_sleep=before_sleep_log(logging.getLogger(__name__), logging.WARNING),
           reraise=True)
    def _test_groq_connection(self):
        """Test Groq API connection with retry logic."""
        try:
            self.groq_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": "Test"}],
                max_tokens=10,
            )
            self.logger.info("✅ Groq API connection successful")
        except Exception as e:
            self.logger.error(f"❌ Groq API connection failed: {e}")
            raise
    def _create_vector_store(self):
        """Create vector store from processed markdown documents."""
        self.logger.info(f"Checking for documents in {self.processed_docs_dir}...")
        doc_files = list(self.processed_docs_dir.glob("**/*.md"))
        if not doc_files:
            self.logger.error(
                f"No markdown files found in {self.processed_docs_dir}. "
                "Please run the enhanced_pdf_processor.py script first."
            )
            raise FileNotFoundError(f"No markdown files found in {self.processed_docs_dir}")
        self.logger.info(f"Found {len(doc_files)} markdown documents to process.")

        # Load documents using UnstructuredMarkdownLoader
        all_docs = []
        for doc_path in doc_files:
            try:
                loader = UnstructuredMarkdownLoader(str(doc_path))
                loaded_docs = loader.load()
                # Ensure 'source' metadata is present so the context string can cite it.
                for doc in loaded_docs:
                    if 'source' not in doc.metadata:
                        doc.metadata['source'] = str(doc_path)
                all_docs.extend(loaded_docs)
            except Exception as e:
                self.logger.error(f"Error loading {doc_path}: {e}")
        if not all_docs:
            self.logger.error("Failed to load any documents. Vector store not created.")
            return

        # Split documents into overlapping chunks (see the worked example below)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,    # Reduced from 2048
            chunk_overlap=128,  # Reduced from 256
            separators=["\n\n", "\n", " ", ""]
        )
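        # Worked example of the chunking arithmetic (illustrative, assuming splits
        # land exactly on the size limit): with chunk_size=1024 and
        # chunk_overlap=128 the effective stride is 1024 - 128 = 896 characters,
        # so a 10,000-character document yields roughly
        # ceil((10000 - 1024) / 896) + 1 = 12 chunks. In practice,
        # RecursiveCharacterTextSplitter prefers to break on the separators
        # listed above, so real chunk counts vary slightly.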
        chunks = text_splitter.split_documents(all_docs)
        self.logger.info(f"Created {len(chunks)} chunks from {len(all_docs)} documents.")

        # Create embeddings and build index
        embeddings, count = self.vector_store.create_embeddings(chunks)
        self.vector_store.build_index(embeddings)
        self.vector_store.save_vector_store()
        self.logger.info(f"Created vector store with {count} embeddings.")
    def query(self,
              query: str,
              history: Optional[List[Dict[str, str]]] = None,
              k: int = 15,            # Reduced from 30
              top_n_rerank: int = 3,  # Reduced from 5
              use_llm: bool = True) -> MedicalResponse:
        """Query the Groq medical RAG system with re-ranking."""
        start_time = time.time()

        # Stage 1: Initial retrieval from the vector store
        docs = self.vector_store.search(query=query, k=k)
        if not docs:
            return self._create_no_results_response(query)

        # Stage 2: Re-ranking with the Cross-Encoder
        sentence_pairs = [[query, doc.content] for doc in docs]
        scores = self.reranker.predict(sentence_pairs)

        # Combine docs with scores and sort by descending relevance
        doc_score_pairs = list(zip(docs, scores))
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)

        # Select the top N results after re-ranking
        reranked_docs = [pair[0] for pair in doc_score_pairs[:top_n_rerank]]
        reranked_scores = [pair[1] for pair in doc_score_pairs[:top_n_rerank]]

        # Prepare context with rich metadata for the LLM
        context_parts = []
        for i, doc in enumerate(reranked_docs, 1):
            citation = doc.metadata.get('citation')
            if not citation:
                source_path = doc.metadata.get('source', 'Unknown')
                citation = Path(source_path).parent.name
            # Add a reference number to each citation
            context_parts.append(f"[{i}] Citation: {citation}\n\nContent: {doc.content}")
        context = "\n\n---\n\n".join(context_parts)

        confidence = self._calculate_confidence(reranked_scores, use_llm)

        # Use a set to get unique citations for display
        sources = list(set(
            doc.metadata.get('citation', Path(doc.metadata.get('source', 'Unknown')).parent.name)
            for doc in reranked_docs
        ))

        if use_llm:
            # Phase 4: Persona-driven, structured response generation
            system_prompt = (
                "You are 'VedaMD', a world-class medical expert and a compassionate assistant for healthcare professionals in Sri Lanka. "
                "Your primary goal is to provide accurate, evidence-based clinical information based ONLY on the provided context, which is sourced from official Sri Lankan maternal health guidelines. "
                "Your tone should be professional, clear, and supportive.\n\n"
                "**CRITICAL INSTRUCTIONS:**\n"
                "1. **Strictly Context-Bound:** Your answer MUST be based exclusively on the 'Content' provided for each source. Do not use any external knowledge or provide information not present in the context.\n"
                "2. **Markdown Formatting:** Structure your answers for maximum clarity. Use markdown for formatting:\n"
                "   - Use headings (`##`) for main topics.\n"
                "   - Use bullet points (`-` or `*`) for lists of symptoms, recommendations, or steps.\n"
                "   - Use bold (`**text**`) to emphasize key terms, dosages, or critical warnings.\n"
                "3. **Synthesize, Don't Just Copy:** Read all context pieces, synthesize the information, and provide a comprehensive answer. Do not repeat information.\n"
                "4. **Scientific Citations:** Use numbered citations [1], [2], etc. in your answer text to reference specific information. At the end, list all sources under a 'References:' heading in scientific format:\n"
                "   [1] Title of Guideline/Document\n"
                "   [2] Title of Another Guideline/Document\n"
                "5. **Disclaimer:** At the end of EVERY response, include the following disclaimer: '_This information is for clinical reference based on Sri Lankan guidelines and does not replace professional medical judgment._'"
            )
            return self._create_llm_response(system_prompt, context, query, confidence, sources, start_time, history)
        else:
            # If not using the LLM, return the retrieved context directly
            return MedicalResponse(
                answer=context,
                confidence=confidence,
                sources=sources,
                query_time=time.time() - start_time
            )
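
    # The assembled context passed to the LLM follows the f-string above, e.g.
    # (chunk text and citation name are illustrative):
    #
    #   [1] Citation: maternal_sepsis_guideline
    #
    #   Content: <first re-ranked chunk text>
    #
    #   ---
    #
    #   [2] Citation: <next citation>
    #   ...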
    def _create_llm_response(self, system_prompt: str, context: str, query: str,
                             confidence: float, sources: List[str], start_time: float,
                             history: Optional[List[Dict[str, str]]] = None) -> MedicalResponse:
        """Helper to generate a response from the LLM."""
        try:
            messages = [
                {
                    "role": "system",
                    "content": system_prompt,
                }
            ]
            # Add conversation history to the messages
            if history:
                messages.extend(history)
            # Add the current query
            messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})

            chat_completion = self.groq_client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                temperature=0.7,
                max_tokens=2048,
                top_p=1,
                stream=False
            )
            response_content = chat_completion.choices[0].message.content
            return MedicalResponse(
                answer=response_content,
                confidence=confidence,
                sources=sources,
                query_time=time.time() - start_time,
            )
        except Exception as e:
            self.logger.error(f"Error during Groq API call: {e}")
            return MedicalResponse(
                answer=f"Sorry, I encountered an error while generating the response: {e}",
                confidence=0.0,
                sources=sources,
                query_time=time.time() - start_time
            )
    def _calculate_confidence(self, scores: List[float], use_llm: bool) -> float:
        """
        Calculate a confidence score from the re-ranked results.
        For LLM responses, we can be more optimistic.
        """
        if not scores:
            return 0.0
        # Simple average of the raw cross-encoder scores (unbounded logits)
        avg_score = sum(scores) / len(scores)
        # Sigmoid scaling maps the average into (0, 1) for a better confidence representation
        confidence = 1 / (1 + np.exp(-avg_score))
        if use_llm:
            return min(confidence * 1.2, 1.0)  # Boost confidence for LLM answers, capped at 1.0
        return confidence
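
    # Worked example (illustrative numbers): the cross-encoder emits raw logits,
    # so three re-ranked scores of [4.1, 2.3, -0.2] average to about 2.07;
    # sigmoid(2.07) ≈ 0.89, and with the LLM boost min(0.89 * 1.2, 1.0) = 1.0.
    # Weak matches stay low even after the boost, e.g. an average of -1.0 gives
    # sigmoid(-1.0) ≈ 0.27, boosted to ≈ 0.32.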
    def _create_no_results_response(self, query: str) -> MedicalResponse:
        """Helper for building the no-results response"""
        return MedicalResponse(
            answer="No relevant documents found for your query. Please try rephrasing your question.",
            confidence=0.0,
            sources=[],
            query_time=0.0
        )
def main(recreate_index: bool = False):
    """Main function to initialize and test the RAG system."""
    print("Initializing Groq Medical RAG system...")
    try:
        rag_system = GroqMedicalRAG()
        if recreate_index:
            print("Recreating index as requested...")
            # Re-initialize with force_recreate=True
            rag_system._initialize_system(force_recreate=True)
            print("✅ Index recreated successfully.")
            return  # Exit after recreating the index
        print("✅ System initialized successfully.")

        # Example query for testing
        print("\n--- Testing with an example query ---")
        query = "What is the management for puerperal sepsis?"
        print(f"Query: {query}")
        response = rag_system.query(query)
        print("\n--- Response ---")
        print(f"Answer: {response.answer}")
        print(f"Confidence: {response.confidence:.2f}")
        print(f"Sources: {response.sources}")
        print(f"Query Time: {response.query_time:.2f}s")
        print("--------------------\n")
    except Exception as e:
        print(f"An error occurred: {e}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Groq Medical RAG System CLI")
    parser.add_argument(
        "--recreate-index",
        action="store_true",
        help="If set, deletes the existing vector store and creates a new one."
    )
    args = parser.parse_args()
    main(recreate_index=args.recreate_index)
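
# Programmatic usage sketch (illustrative; assumes GROQ_API_KEY is set and the
# vector store has been built; the import path is a guess based on the relative
# import of simple_vector_store above):
#
#   from src.groq_medical_rag import GroqMedicalRAG
#
#   rag = GroqMedicalRAG()
#   first = rag.query("What is the management for puerperal sepsis?")
#   # Pass prior turns as history so follow-up questions keep their context
#   history = [
#       {"role": "user", "content": "What is the management for puerperal sepsis?"},
#       {"role": "assistant", "content": first.answer},
#   ]
#   follow_up = rag.query("What antibiotics are recommended?", history=history)
#   print(follow_up.answer, follow_up.sources)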