rivapereira123 commited on
Commit
4c38b03
Β·
verified Β·
1 Parent(s): b4078c2

Upload app (16).py

Browse files
Files changed (1) hide show
  1. app (16).py +1032 -0
app (16).py ADDED
@@ -0,0 +1,1032 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import logging
5
+ import warnings
6
+ from pathlib import Path
7
+ from typing import List, Dict, Any, Optional, Tuple
8
+ import hashlib
9
+ import pickle
10
+ from datetime import datetime
11
+ import time
12
+ import asyncio
13
+ from concurrent.futures import ThreadPoolExecutor
14
+
15
+ # Suppress warnings for cleaner output
16
+ warnings.filterwarnings("ignore")
17
+
18
+ # Core dependencies
19
+ import gradio as gr
20
+ import numpy as np
21
+ import pandas as pd
22
+ from sentence_transformers import SentenceTransformer
23
+ import faiss
24
+ import torch
25
+ from transformers import (
26
+ AutoTokenizer,
27
+ AutoModelForCausalLM,
28
+ BitsAndBytesConfig,
29
+ pipeline
30
+ )
31
+
32
+ # Medical knowledge validation
33
+ import re
34
+
35
+ # Configure logging
36
+ logging.basicConfig(
37
+ level=logging.INFO,
38
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
39
+ )
40
+ logger = logging.getLogger(__name__)
41
+
42
+ class MedicalFactChecker:
43
+ """Enhanced medical fact checker with faster validation"""
44
+
45
+ def __init__(self):
46
+ self.medical_facts = self._load_medical_facts()
47
+ self.contraindications = self._load_contraindications()
48
+ self.dosage_patterns = self._compile_dosage_patterns()
49
+ self.definitive_patterns = [
50
+ re.compile(r, re.IGNORECASE) for r in [
51
+ r'always\s+(?:use|take|apply)',
52
+ r'never\s+(?:use|take|apply)',
53
+ r'will\s+(?:cure|heal|fix)',
54
+ r'guaranteed\s+to',
55
+ r'completely\s+(?:safe|effective)'
56
+ ]
57
+ ]
58
+
59
+ def _load_medical_facts(self) -> Dict[str, Any]:
60
+ """Pre-loaded medical facts for Gaza context"""
61
+ return {
62
+ "burn_treatment": {
63
+ "cool_water": "Use clean, cool (not ice-cold) water for 10-20 minutes",
64
+ "no_ice": "Never apply ice directly to burns",
65
+ "clean_cloth": "Cover with clean, dry cloth if available"
66
+ },
67
+ "wound_care": {
68
+ "pressure": "Apply direct pressure to control bleeding",
69
+ "elevation": "Elevate injured limb if possible",
70
+ "clean_hands": "Clean hands before treating wounds when possible"
71
+ },
72
+ "infection_signs": {
73
+ "redness": "Increasing redness around wound",
74
+ "warmth": "Increased warmth at wound site",
75
+ "pus": "Yellow or green discharge",
76
+ "fever": "Fever may indicate systemic infection"
77
+ }
78
+ }
79
+
80
+ def _load_contraindications(self) -> Dict[str, List[str]]:
81
+ """Pre-loaded contraindications for common treatments"""
82
+ return {
83
+ "aspirin": ["children under 16", "bleeding disorders", "stomach ulcers"],
84
+ "ibuprofen": ["kidney disease", "heart failure", "stomach bleeding"],
85
+ "hydrogen_peroxide": ["deep wounds", "closed wounds", "eyes"],
86
+ "tourniquets": ["non-life-threatening bleeding", "without proper training"]
87
+ }
88
+
89
+ def _compile_dosage_patterns(self) -> List[re.Pattern]:
90
+ """Pre-compiled dosage patterns"""
91
+ patterns = [
92
+ r'\d+\s*mg\b', # milligrams
93
+ r'\d+\s*g\b', # grams
94
+ r'\d+\s*ml\b', # milliliters
95
+ r'\d+\s*tablets?\b', # tablets
96
+ r'\d+\s*times?\s+(?:per\s+)?day\b', # frequency
97
+ r'every\s+\d+\s+hours?\b' # intervals
98
+ ]
99
+ return [re.compile(pattern, re.IGNORECASE) for pattern in patterns]
100
+
101
+ def check_medical_accuracy(self, response: str, context: str) -> Dict[str, Any]:
102
+ """Enhanced medical accuracy check with Gaza-specific considerations"""
103
+ issues = []
104
+ warnings = []
105
+ accuracy_score = 0.0
106
+
107
+ # Check for contraindications (faster keyword matching)
108
+ response_lower = response.lower()
109
+ for medication, contra_list in self.contraindications.items():
110
+ if medication in response_lower:
111
+ for contra in contra_list:
112
+ if any(word in response_lower for word in contra.split()):
113
+ issues.append(f"Potential contraindication: {medication} with {contra}")
114
+ accuracy_score -= 0.3
115
+ break
116
+
117
+ # Context alignment using Jaccard similarity
118
+ if context:
119
+ resp_words = set(response_lower.split())
120
+ ctx_words = set(context.lower().split())
121
+ context_similarity = len(resp_words & ctx_words) / len(resp_words | ctx_words) if ctx_words else 0.0
122
+ if context_similarity < 0.5: # Lowered threshold for Gaza context
123
+ warnings.append(f"Low context similarity: {context_similarity:.2f}")
124
+ accuracy_score -= 0.1
125
+ else:
126
+ context_similarity = 0.0
127
+
128
+ # Gaza-specific resource checks
129
+ gaza_resources = ["clean water", "sterile", "hospital", "ambulance", "electricity"]
130
+ if any(resource in response_lower for resource in gaza_resources):
131
+ warnings.append("Consider resource limitations in Gaza context")
132
+ accuracy_score -= 0.05
133
+
134
+ # Unsupported claims check
135
+ for pattern in self.definitive_patterns:
136
+ if pattern.search(response):
137
+ issues.append(f"Unsupported definitive claim detected")
138
+ accuracy_score -= 0.4
139
+ break
140
+
141
+ # Dosage validation
142
+ for pattern in self.dosage_patterns:
143
+ if pattern.search(response):
144
+ warnings.append("Dosage detected - verify with professional")
145
+ accuracy_score -= 0.1
146
+ break
147
+
148
+ confidence_score = max(0.0, min(1.0, 0.8 + accuracy_score))
149
+
150
+ return {
151
+ "confidence_score": confidence_score,
152
+ "issues": issues,
153
+ "warnings": warnings,
154
+ "context_similarity": context_similarity,
155
+ "is_safe": len(issues) == 0 and confidence_score > 0.5
156
+ }
157
+
158
+ class OptimizedGazaKnowledgeBase:
159
+ """Optimized knowledge base that loads pre-made FAISS index and assets"""
160
+
161
+ def __init__(self, vector_store_dir: str = "./vector_store"):
162
+ self.vector_store_dir = Path(vector_store_dir)
163
+ self.faiss_index = None
164
+ self.embedding_model = None
165
+ self.chunks = []
166
+ self.metadata = []
167
+ self.is_initialized = False
168
+
169
+ def initialize(self):
170
+ """Load pre-made FAISS index and associated data"""
171
+ try:
172
+ logger.info("πŸ”„ Loading pre-made FAISS index and assets...")
173
+
174
+ # 1. Load FAISS index
175
+ index_path = self.vector_store_dir / "index.faiss"
176
+ if not index_path.exists():
177
+ raise FileNotFoundError(f"FAISS index not found at {index_path}")
178
+
179
+ self.faiss_index = faiss.read_index(str(index_path))
180
+ logger.info(f"βœ… Loaded FAISS index: {self.faiss_index.ntotal} vectors, {self.faiss_index.d} dimensions")
181
+
182
+ # 2. Load chunks
183
+ chunks_path = self.vector_store_dir / "chunks.txt"
184
+ if not chunks_path.exists():
185
+ raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
186
+
187
+ with open(chunks_path, 'r', encoding='utf-8') as f:
188
+ lines = f.readlines()
189
+
190
+ # Parse chunks from the formatted file
191
+ current_chunk = ""
192
+ for line in lines:
193
+ line = line.strip()
194
+ if line.startswith("=== Chunk") and current_chunk:
195
+ self.chunks.append(current_chunk.strip())
196
+ current_chunk = ""
197
+ elif not line.startswith("===") and not line.startswith("Source:") and not line.startswith("Length:"):
198
+ current_chunk += line + " "
199
+
200
+ # Add the last chunk
201
+ if current_chunk:
202
+ self.chunks.append(current_chunk.strip())
203
+
204
+ logger.info(f"βœ… Loaded {len(self.chunks)} text chunks")
205
+
206
+ # 3. Load metadata
207
+ metadata_path = self.vector_store_dir / "metadata.pkl"
208
+ if metadata_path.exists():
209
+ with open(metadata_path, 'rb') as f:
210
+ metadata_dict = pickle.load(f)
211
+
212
+ if isinstance(metadata_dict, dict) and 'metadata' in metadata_dict:
213
+ self.metadata = metadata_dict['metadata']
214
+ logger.info(f"βœ… Loaded {len(self.metadata)} metadata entries")
215
+ else:
216
+ logger.warning("⚠️ Metadata format not recognized, using empty metadata")
217
+ self.metadata = [{}] * len(self.chunks)
218
+ else:
219
+ logger.warning("⚠️ No metadata file found, using empty metadata")
220
+ self.metadata = [{}] * len(self.chunks)
221
+
222
+ # 4. Initialize embedding model for query encoding
223
+ logger.info("πŸ”„ Loading embedding model for queries...")
224
+ self.embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
225
+ logger.info("βœ… Embedding model loaded")
226
+
227
+ # 5. Verify data consistency
228
+ if len(self.chunks) != self.faiss_index.ntotal:
229
+ logger.warning(f"⚠️ Mismatch: {len(self.chunks)} chunks vs {self.faiss_index.ntotal} vectors")
230
+ # Trim chunks to match index size
231
+ self.chunks = self.chunks[:self.faiss_index.ntotal]
232
+ self.metadata = self.metadata[:self.faiss_index.ntotal]
233
+ logger.info(f"βœ… Trimmed to {len(self.chunks)} chunks to match index")
234
+
235
+ self.is_initialized = True
236
+ logger.info("πŸŽ‰ Knowledge base initialization complete!")
237
+
238
+ except Exception as e:
239
+ logger.error(f"❌ Failed to initialize knowledge base: {e}")
240
+ raise
241
+
242
+ def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
243
+ """Search using pre-made FAISS index"""
244
+ if not self.is_initialized:
245
+ raise RuntimeError("Knowledge base not initialized")
246
+
247
+ try:
248
+ # 1. Encode query
249
+ query_embedding = self.embedding_model.encode([query])
250
+ query_vector = np.array(query_embedding, dtype=np.float32)
251
+
252
+ # 2. Search FAISS index
253
+ distances, indices = self.faiss_index.search(query_vector, k)
254
+
255
+ # 3. Prepare results
256
+ results = []
257
+ for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
258
+ if idx >= 0 and idx < len(self.chunks): # Valid index
259
+ chunk_metadata = self.metadata[idx] if idx < len(self.metadata) else {}
260
+
261
+ result = {
262
+ "text": self.chunks[idx],
263
+ "score": float(1.0 / (1.0 + distance)), # Convert distance to similarity score
264
+ "source": chunk_metadata.get("source", "unknown"),
265
+ "chunk_index": int(idx),
266
+ "distance": float(distance),
267
+ "metadata": chunk_metadata
268
+ }
269
+ results.append(result)
270
+
271
+ logger.info(f"πŸ” Search for '{query[:50]}...' returned {len(results)} results")
272
+ return results
273
+
274
+ except Exception as e:
275
+ logger.error(f"❌ Search error: {e}")
276
+ return []
277
+
278
+ def get_stats(self) -> Dict[str, Any]:
279
+ """Get knowledge base statistics"""
280
+ if not self.is_initialized:
281
+ return {"status": "not_initialized"}
282
+
283
+ return {
284
+ "status": "initialized",
285
+ "total_chunks": len(self.chunks),
286
+ "total_vectors": self.faiss_index.ntotal,
287
+ "embedding_dimension": self.faiss_index.d,
288
+ "index_type": type(self.faiss_index).__name__,
289
+ "sources": list(set(meta.get("source", "unknown") for meta in self.metadata))
290
+ }
291
+
292
+ class OptimizedGazaRAGSystem:
293
+ """Optimized RAG system using pre-made assets"""
294
+
295
+ def __init__(self, vector_store_dir: str = "./vector_store"):
296
+ self.knowledge_base = OptimizedGazaKnowledgeBase(vector_store_dir)
297
+ self.fact_checker = MedicalFactChecker()
298
+ self.llm = None
299
+ self.tokenizer = None
300
+ self.system_prompt = self._create_system_prompt()
301
+ self.generation_pipeline = None
302
+ self.response_cache = {}
303
+ self.executor = ThreadPoolExecutor(max_workers=2)
304
+
305
+ def initialize(self):
306
+ """Initialize the optimized RAG system"""
307
+ logger.info("πŸš€ Initializing Optimized Gaza RAG System...")
308
+ self.knowledge_base.initialize()
309
+ logger.info("βœ… Optimized Gaza RAG System ready!")
310
+
311
+ def _initialize_llm(self):
312
+ """Enhanced LLM initialization with better error handling"""
313
+ if self.llm is not None:
314
+ return
315
+
316
+ model_name = "microsoft/Phi-3-mini-4k-instruct"
317
+ try:
318
+ logger.info(f"πŸ”„ Loading LLM: {model_name}")
319
+
320
+ # Enhanced quantization configuration
321
+ quantization_config = BitsAndBytesConfig(
322
+ load_in_4bit=True,
323
+ bnb_4bit_use_double_quant=True,
324
+ bnb_4bit_quant_type="nf4",
325
+ bnb_4bit_compute_dtype=torch.float16,
326
+ )
327
+
328
+ self.tokenizer = AutoTokenizer.from_pretrained(
329
+ model_name,
330
+ trust_remote_code=True,
331
+ padding_side="left"
332
+ )
333
+
334
+ if self.tokenizer.pad_token is None:
335
+ self.tokenizer.pad_token = self.tokenizer.eos_token
336
+
337
+ self.llm = AutoModelForCausalLM.from_pretrained(
338
+ model_name,
339
+ quantization_config=quantization_config,
340
+ device_map="auto",
341
+ trust_remote_code=True,
342
+ torch_dtype=torch.float16,
343
+ low_cpu_mem_usage=True
344
+ )
345
+
346
+ self.generation_pipeline = pipeline(
347
+ "text-generation",
348
+ model=self.llm,
349
+ tokenizer=self.tokenizer,
350
+ device_map="auto",
351
+ torch_dtype=torch.float16,
352
+ return_full_text=False
353
+ )
354
+
355
+ logger.info("βœ… LLM loaded successfully")
356
+
357
+ except Exception as e:
358
+ logger.error(f"❌ Error loading primary model: {e}")
359
+ self._initialize_fallback_llm()
360
+
361
+ def _initialize_fallback_llm(self):
362
+ """Enhanced fallback model with better error handling"""
363
+ try:
364
+ logger.info("πŸ”„ Loading fallback model...")
365
+
366
+ fallback_model = "microsoft/DialoGPT-small"
367
+ self.tokenizer = AutoTokenizer.from_pretrained(fallback_model)
368
+ self.llm = AutoModelForCausalLM.from_pretrained(
369
+ fallback_model,
370
+ torch_dtype=torch.float32,
371
+ low_cpu_mem_usage=True
372
+ )
373
+
374
+ if self.tokenizer.pad_token is None:
375
+ self.tokenizer.pad_token = self.tokenizer.eos_token
376
+
377
+ self.generation_pipeline = pipeline(
378
+ "text-generation",
379
+ model=self.llm,
380
+ tokenizer=self.tokenizer,
381
+ return_full_text=False
382
+ )
383
+
384
+ logger.info("βœ… Fallback model loaded successfully")
385
+
386
+ except Exception as e:
387
+ logger.error(f"❌ Fallback model failed: {e}")
388
+ self.llm = None
389
+ self.generation_pipeline = None
390
+
391
+ def _create_system_prompt(self) -> str:
392
+ """Enhanced system prompt for Gaza context"""
393
+ return """You are a medical AI assistant specifically designed for Gaza healthcare workers operating under siege conditions.
394
+
395
+ CRITICAL GUIDELINES:
396
+ - Provide practical first aid guidance considering limited resources (water, electricity, medical supplies)
397
+ - Always prioritize patient safety and recommend professional medical help when available
398
+ - Consider Gaza's specific challenges: blockade, limited hospitals, frequent power outages
399
+ - Suggest alternative treatments when standard medical supplies are unavailable
400
+ - Never provide definitive diagnoses - only supportive care guidance
401
+ - Be culturally sensitive and aware of the humanitarian crisis context
402
+
403
+ RESOURCE CONSTRAINTS TO CONSIDER:
404
+ - Limited clean water availability
405
+ - Frequent electricity outages
406
+ - Restricted medical supply access
407
+ - Overwhelmed healthcare facilities
408
+ - Limited transportation for medical emergencies
409
+
410
+ Provide clear, actionable advice while emphasizing the need for professional medical care when possible."""
411
+
412
+ async def generate_response_async(self, query: str, progress_callback=None) -> Dict[str, Any]:
413
+ """Async response generation with progress tracking"""
414
+ start_time = time.time()
415
+
416
+ if progress_callback:
417
+ progress_callback(0.1, "πŸ” Checking cache...")
418
+
419
+ # Check cache first
420
+ query_hash = hashlib.md5(query.encode()).hexdigest()
421
+ if query_hash in self.response_cache:
422
+ cached_response = self.response_cache[query_hash]
423
+ cached_response["cached"] = True
424
+ cached_response["response_time"] = 0.1
425
+ if progress_callback:
426
+ progress_callback(1.0, "πŸ’Ύ Retrieved from cache!")
427
+ return cached_response
428
+
429
+ try:
430
+ if progress_callback:
431
+ progress_callback(0.2, "πŸ€– Initializing LLM...")
432
+
433
+ # Initialize LLM only when needed
434
+ if self.llm is None:
435
+ await asyncio.get_event_loop().run_in_executor(
436
+ self.executor, self._initialize_llm
437
+ )
438
+
439
+ if progress_callback:
440
+ progress_callback(0.4, "πŸ” Searching knowledge base...")
441
+
442
+ # Enhanced knowledge retrieval using pre-made index
443
+ search_results = await asyncio.get_event_loop().run_in_executor(
444
+ self.executor, self.knowledge_base.search, query, 5
445
+ )
446
+
447
+ if progress_callback:
448
+ progress_callback(0.6, "πŸ“ Preparing context...")
449
+
450
+ context = self._prepare_context(search_results)
451
+
452
+ if progress_callback:
453
+ progress_callback(0.8, "🧠 Generating response...")
454
+
455
+ # Generate response
456
+ response = await asyncio.get_event_loop().run_in_executor(
457
+ self.executor, self._generate_response, query, context
458
+ )
459
+
460
+ if progress_callback:
461
+ progress_callback(0.9, "πŸ›‘οΈ Validating safety...")
462
+
463
+ # Enhanced safety check
464
+ safety_check = self.fact_checker.check_medical_accuracy(response, context)
465
+
466
+ # Prepare final response
467
+ final_response = self._prepare_final_response(
468
+ response,
469
+ search_results,
470
+ safety_check,
471
+ time.time() - start_time
472
+ )
473
+
474
+ # Cache the response (limit cache size)
475
+ if len(self.response_cache) < 100:
476
+ self.response_cache[query_hash] = final_response
477
+
478
+ if progress_callback:
479
+ progress_callback(1.0, "βœ… Complete!")
480
+
481
+ return final_response
482
+
483
+ except Exception as e:
484
+ logger.error(f"❌ Error generating response: {e}")
485
+ if progress_callback:
486
+ progress_callback(1.0, f"❌ Error: {str(e)}")
487
+ return self._create_error_response(str(e))
488
+
489
+ def _prepare_context(self, search_results: List[Dict[str, Any]]) -> str:
490
+ """Enhanced context preparation with better formatting"""
491
+ if not search_results:
492
+ return "No specific medical guidance found in knowledge base. Provide general first aid principles."
493
+
494
+ context_parts = []
495
+ for i, result in enumerate(search_results, 1):
496
+ source = result.get('source', 'unknown')
497
+ text = result.get('text', '')
498
+ score = result.get('score', 0.0)
499
+
500
+ # Truncate long text but preserve important information
501
+ if len(text) > 400:
502
+ text = text[:400] + "..."
503
+
504
+ context_parts.append(f"[Source {i}: {source} - Relevance: {score:.2f}]\n{text}")
505
+
506
+ return "\n\n".join(context_parts)
507
+
508
+ def _generate_response(self, query: str, context: str) -> str:
509
+ """Enhanced response generation using model.generate() to avoid DynamicCache errors"""
510
+ if self.llm is None or self.tokenizer is None:
511
+ return self._generate_fallback_response(query, context)
512
+
513
+ # Build prompt with Gaza-specific context
514
+ prompt = f"""{self.system_prompt}
515
+
516
+ MEDICAL KNOWLEDGE CONTEXT:
517
+ {context}
518
+
519
+ PATIENT QUESTION: {query}
520
+
521
+ RESPONSE (provide practical, Gaza-appropriate medical guidance):"""
522
+
523
+ try:
524
+ # Tokenize and move to correct device
525
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
526
+ if hasattr(self.llm, 'device'):
527
+ inputs = inputs.to(self.llm.device)
528
+
529
+ # Generate the response
530
+ with torch.no_grad():
531
+ outputs = self.llm.generate(
532
+ **inputs,
533
+ max_new_tokens=300,
534
+ temperature=0.3,
535
+ pad_token_id=self.tokenizer.eos_token_id,
536
+ do_sample=True,
537
+ repetition_penalty=1.15,
538
+ no_repeat_ngram_size=3
539
+ )
540
+
541
+ # Decode and clean up
542
+ response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
543
+
544
+ # Extract only the generated part
545
+ if "RESPONSE (provide practical, Gaza-appropriate medical guidance):" in response_text:
546
+ response_text = response_text.split("RESPONSE (provide practical, Gaza-appropriate medical guidance):")[1]
547
+
548
+ # Clean up the response
549
+ lines = response_text.split('\n')
550
+ unique_lines = []
551
+ for line in lines:
552
+ line = line.strip()
553
+ if line and line not in unique_lines and len(line) > 10: # Filter out very short lines
554
+ unique_lines.append(line)
555
+
556
+ return '\n'.join(unique_lines[:10]) # Limit to 10 lines
557
+
558
+ except Exception as e:
559
+ logger.error(f"❌ Error in LLM generate(): {e}")
560
+ return self._generate_fallback_response(query, context)
561
+
562
+ def _generate_fallback_response(self, query: str, context: str) -> str:
563
+ """Enhanced fallback response with Gaza-specific guidance"""
564
+ gaza_guidance = {
565
+ "burn": "For burns: Use clean, cool water if available. If water is scarce, use clean cloth. Avoid ice. Seek medical help urgently.",
566
+ "bleeding": "For bleeding: Apply direct pressure with clean cloth. Elevate if possible. If severe, seek immediate medical attention.",
567
+ "wound": "For wounds: Clean hands if possible. Apply pressure to stop bleeding. Cover with clean material. Watch for infection signs.",
568
+ "infection": "Signs of infection: Redness, warmth, swelling, pus, fever. Seek medical care immediately if available.",
569
+ "pain": "For pain management: Rest, elevation, cold/warm compress as appropriate. Avoid aspirin in children."
570
+ }
571
+
572
+ query_lower = query.lower()
573
+ for condition, guidance in gaza_guidance.items():
574
+ if condition in query_lower:
575
+ return f"{guidance}\n\nContext from medical sources:\n{context[:200]}..."
576
+
577
+ return f"Medical guidance for: {query}\n\nGeneral advice: Prioritize safety, seek professional help when available, consider resource limitations in Gaza.\n\nRelevant information:\n{context[:300]}..."
578
+
579
+ def _prepare_final_response(
580
+ self,
581
+ response: str,
582
+ search_results: List[Dict[str, Any]],
583
+ safety_check: Dict[str, Any],
584
+ response_time: float
585
+ ) -> Dict[str, Any]:
586
+ """Enhanced final response preparation with more metadata"""
587
+
588
+ # Add safety warnings if needed
589
+ if not safety_check["is_safe"]:
590
+ response = f"⚠️ MEDICAL CAUTION: {response}\n\n🚨 Please verify this guidance with a medical professional when possible."
591
+
592
+ # Add Gaza-specific disclaimer
593
+ response += "\n\nπŸ“ Gaza Context: This guidance considers resource limitations. Adapt based on available supplies and seek professional medical care when accessible."
594
+
595
+ # Extract unique sources
596
+ sources = list(set(res.get("source", "unknown") for res in search_results)) if search_results else []
597
+
598
+ # Calculate confidence based on multiple factors
599
+ base_confidence = safety_check.get("confidence_score", 0.5)
600
+ context_bonus = 0.1 if search_results else 0.0
601
+ safety_penalty = 0.2 if not safety_check.get("is_safe", True) else 0.0
602
+
603
+ final_confidence = max(0.0, min(1.0, base_confidence + context_bonus - safety_penalty))
604
+
605
+ return {
606
+ "response": response,
607
+ "confidence": final_confidence,
608
+ "sources": sources,
609
+ "search_results_count": len(search_results),
610
+ "safety_issues": safety_check.get("issues", []),
611
+ "safety_warnings": safety_check.get("warnings", []),
612
+ "response_time": round(response_time, 2),
613
+ "timestamp": datetime.now().isoformat()[:19],
614
+ "cached": False
615
+ }
616
+
617
+ def _create_error_response(self, error_msg: str) -> Dict[str, Any]:
618
+ """Enhanced error response with helpful information"""
619
+ return {
620
+ "response": f"⚠️ System Error: Unable to process your medical query at this time.\n\nError: {error_msg}\n\n🚨 For immediate medical emergencies, seek professional help directly.\n\nπŸ“ž Gaza Emergency Numbers:\n- Palestinian Red Crescent: 101\n- Civil Defense: 102",
621
+ "confidence": 0.0,
622
+ "sources": [],
623
+ "search_results_count": 0,
624
+ "safety_issues": ["System error occurred"],
625
+ "safety_warnings": ["Unable to validate medical accuracy"],
626
+ "response_time": 0.0,
627
+ "timestamp": datetime.now().isoformat()[:19],
628
+ "cached": False,
629
+ "error": True
630
+ }
631
+
632
+ # Global system instance
633
+ optimized_rag_system = None
634
+
635
+ def initialize_optimized_system(vector_store_dir: str = "./vector_store"):
636
+ """Initialize optimized system with pre-made assets"""
637
+ global optimized_rag_system
638
+ if optimized_rag_system is None:
639
+ try:
640
+ optimized_rag_system = OptimizedGazaRAGSystem(vector_store_dir)
641
+ optimized_rag_system.initialize()
642
+ logger.info("βœ… Optimized Gaza RAG System initialized successfully")
643
+ except Exception as e:
644
+ logger.error(f"❌ Failed to initialize optimized system: {e}")
645
+ raise
646
+ return optimized_rag_system
647
+
648
+ def process_medical_query_with_progress(query: str, progress=gr.Progress()) -> Tuple[str, str, str]:
649
+ """Enhanced query processing with detailed progress tracking and status updates"""
650
+ if not query.strip():
651
+ return "Please enter a medical question.", "", "⚠️ No query provided"
652
+
653
+ try:
654
+ # Initialize system with progress
655
+ progress(0.05, desc="πŸ”§ Initializing optimized system...")
656
+ system = initialize_optimized_system()
657
+
658
+ # Create async event loop for progress tracking
659
+ loop = asyncio.new_event_loop()
660
+ asyncio.set_event_loop(loop)
661
+
662
+ def progress_callback(value, desc):
663
+ progress(value, desc=desc)
664
+
665
+ try:
666
+ # Run async generation with progress
667
+ result = loop.run_until_complete(
668
+ system.generate_response_async(query, progress_callback)
669
+ )
670
+ finally:
671
+ loop.close()
672
+
673
+ # Prepare response with enhanced metadata
674
+ response = result["response"]
675
+
676
+ # Prepare detailed metadata
677
+ metadata_parts = [
678
+ f"🎯 Confidence: {result['confidence']:.1%}",
679
+ f"⏱️ Response: {result['response_time']}s",
680
+ f"πŸ“š Sources: {result['search_results_count']} found"
681
+ ]
682
+
683
+ if result.get('cached'):
684
+ metadata_parts.append("πŸ’Ύ Cached")
685
+
686
+ if result.get('sources'):
687
+ metadata_parts.append(f"πŸ“– Refs: {', '.join(result['sources'][:2])}")
688
+
689
+ metadata = " | ".join(metadata_parts)
690
+
691
+ # Prepare status with warnings/issues
692
+ status_parts = []
693
+ if result.get('safety_warnings'):
694
+ status_parts.append(f"⚠️ {len(result['safety_warnings'])} warnings")
695
+ if result.get('safety_issues'):
696
+ status_parts.append(f"🚨 {len(result['safety_issues'])} issues")
697
+ if not status_parts:
698
+ status_parts.append("βœ… Safe response")
699
+
700
+ status = " | ".join(status_parts)
701
+
702
+ return response, metadata, status
703
+
704
+ except Exception as e:
705
+ logger.error(f"❌ Error processing query: {e}")
706
+ error_response = f"⚠️ Error processing your query: {str(e)}\n\n🚨 For medical emergencies, seek immediate professional help."
707
+ error_metadata = f"❌ Error at {datetime.now().strftime('%H:%M:%S')}"
708
+ error_status = "🚨 System error occurred"
709
+ return error_response, error_metadata, error_status
710
+
711
+ def get_system_stats() -> str:
712
+ """Get system statistics for display"""
713
+ try:
714
+ system = initialize_optimized_system()
715
+ stats = system.knowledge_base.get_stats()
716
+
717
+ if stats["status"] == "initialized":
718
+ return f"""
719
+ πŸ“Š **System Statistics:**
720
+ - Status: βœ… Initialized
721
+ - Total Chunks: {stats['total_chunks']:,}
722
+ - Vector Dimension: {stats['embedding_dimension']}
723
+ - Index Type: {stats['index_type']}
724
+ - Sources: {len(stats['sources'])} documents
725
+ - Available Sources: {', '.join(stats['sources'][:5])}{'...' if len(stats['sources']) > 5 else ''}
726
+ """
727
+ else:
728
+ return "πŸ“Š System Status: ❌ Not Initialized"
729
+ except Exception as e:
730
+ return f"πŸ“Š System Status: ❌ Error - {str(e)}"
731
+
732
+ def create_optimized_gradio_interface():
733
+ """Create optimized Gradio interface with enhanced features"""
734
+
735
+ # Enhanced CSS with medical theme
736
+ css = """
737
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
738
+
739
+ * {
740
+ font-family: 'Inter', sans-serif !important;
741
+ }
742
+
743
+ .gradio-container {
744
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
745
+ min-height: 100vh;
746
+ }
747
+
748
+ .main-container {
749
+ background: rgba(255, 255, 255, 0.95);
750
+ backdrop-filter: blur(10px);
751
+ border-radius: 20px;
752
+ padding: 30px;
753
+ margin: 20px;
754
+ box-shadow: 0 20px 40px rgba(0,0,0,0.1);
755
+ border: 1px solid rgba(255,255,255,0.2);
756
+ }
757
+
758
+ .header-section {
759
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
760
+ color: white;
761
+ border-radius: 15px;
762
+ padding: 25px;
763
+ margin-bottom: 25px;
764
+ text-align: center;
765
+ box-shadow: 0 10px 30px rgba(102, 126, 234, 0.3);
766
+ }
767
+
768
+ .query-container {
769
+ background: linear-gradient(135deg, #f8f9ff 0%, #e8f2ff 100%);
770
+ border-radius: 15px;
771
+ padding: 20px;
772
+ margin: 15px 0;
773
+ border: 2px solid #667eea;
774
+ transition: all 0.3s ease;
775
+ }
776
+
777
+ .response-container {
778
+ background: linear-gradient(135deg, #fff 0%, #f8f9ff 100%);
779
+ border-radius: 15px;
780
+ padding: 20px;
781
+ margin: 15px 0;
782
+ border: 2px solid #4CAF50;
783
+ min-height: 300px;
784
+ }
785
+
786
+ .submit-btn {
787
+ background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%) !important;
788
+ color: white !important;
789
+ border: none !important;
790
+ border-radius: 12px !important;
791
+ padding: 15px 30px !important;
792
+ font-size: 16px !important;
793
+ font-weight: 600 !important;
794
+ cursor: pointer !important;
795
+ transition: all 0.3s ease !important;
796
+ box-shadow: 0 6px 20px rgba(76, 175, 80, 0.3) !important;
797
+ }
798
+
799
+ .submit-btn:hover {
800
+ transform: translateY(-3px) !important;
801
+ box-shadow: 0 10px 30px rgba(76, 175, 80, 0.4) !important;
802
+ }
803
+
804
+ .stats-container {
805
+ background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
806
+ border-radius: 12px;
807
+ padding: 15px;
808
+ margin: 10px 0;
809
+ border-left: 5px solid #2196F3;
810
+ font-size: 14px;
811
+ }
812
+ """
813
+
814
+ with gr.Blocks(
815
+ css=css,
816
+ title="πŸ₯ Optimized Gaza First Aid Assistant",
817
+ theme=gr.themes.Soft(
818
+ primary_hue="blue",
819
+ secondary_hue="green",
820
+ neutral_hue="slate"
821
+ )
822
+ ) as interface:
823
+
824
+ # Header Section
825
+ with gr.Row(elem_classes=["main-container"]):
826
+ gr.HTML("""
827
+ <div class="header-section">
828
+ <h1 style="margin: 0; font-size: 2.5em; font-weight: 700;">
829
+ πŸ₯ Optimized Gaza First Aid Assistant
830
+ </h1>
831
+ <h2 style="margin: 10px 0 0 0; font-size: 1.2em; font-weight: 400; opacity: 0.9;">
832
+ Powered by Pre-computed FAISS Index & 768-dim Embeddings
833
+ </h2>
834
+ <p style="margin: 15px 0 0 0; font-size: 1em; opacity: 0.8;">
835
+ Lightning-fast medical guidance using pre-processed knowledge base
836
+ </p>
837
+ </div>
838
+ """)
839
+
840
+ # System Stats
841
+ with gr.Row(elem_classes=["main-container"]):
842
+ with gr.Group(elem_classes=["stats-container"]):
843
+ stats_display = gr.Markdown(
844
+ value=get_system_stats(),
845
+ label="πŸ“Š System Status"
846
+ )
847
+
848
+ # Main Interface
849
+ with gr.Row(elem_classes=["main-container"]):
850
+ with gr.Column(scale=2):
851
+ # Query Input Section
852
+ with gr.Group(elem_classes=["query-container"]):
853
+ gr.Markdown("### 🩺 Medical Query Input")
854
+ query_input = gr.Textbox(
855
+ label="Describe your medical situation",
856
+ placeholder="Enter your first aid question or describe the medical emergency...",
857
+ lines=4
858
+ )
859
+
860
+ with gr.Row():
861
+ submit_btn = gr.Button(
862
+ "πŸ” Get Medical Guidance",
863
+ variant="primary",
864
+ elem_classes=["submit-btn"],
865
+ scale=3
866
+ )
867
+ clear_btn = gr.Button(
868
+ "πŸ—‘οΈ Clear",
869
+ variant="secondary",
870
+ scale=1
871
+ )
872
+
873
+ with gr.Column(scale=1):
874
+ # Quick Access
875
+ gr.Markdown("""
876
+ ### ⚑ Optimized Features
877
+
878
+ **πŸš€ Performance:**
879
+ - Pre-computed FAISS index
880
+ - 768-dimensional embeddings
881
+ - Lightning-fast search
882
+ - Optimized for Gaza context
883
+
884
+ **πŸ“š Knowledge Base:**
885
+ - WHO medical protocols
886
+ - ICRC war surgery guides
887
+ - MSF field manuals
888
+ - Gaza-specific adaptations
889
+
890
+ **πŸ›‘οΈ Safety Features:**
891
+ - Real-time fact checking
892
+ - Contraindication detection
893
+ - Gaza resource warnings
894
+ - Professional disclaimers
895
+ """)
896
+
897
+ # Response Section
898
+ with gr.Row(elem_classes=["main-container"]):
899
+ with gr.Column():
900
+ # Main Response
901
+ with gr.Group(elem_classes=["response-container"]):
902
+ gr.Markdown("### 🩹 Medical Guidance Response")
903
+ response_output = gr.Textbox(
904
+ label="AI Medical Guidance",
905
+ lines=15,
906
+ interactive=False,
907
+ placeholder="Your medical guidance will appear here..."
908
+ )
909
+
910
+ # Metadata and Status
911
+ with gr.Row():
912
+ with gr.Column(scale=1):
913
+ metadata_output = gr.Textbox(
914
+ label="πŸ“Š Response Metadata",
915
+ lines=2,
916
+ interactive=False,
917
+ placeholder="Response metadata will appear here..."
918
+ )
919
+
920
+ with gr.Column(scale=1):
921
+ status_output = gr.Textbox(
922
+ label="πŸ›‘οΈ Safety Status",
923
+ lines=2,
924
+ interactive=False,
925
+ placeholder="Safety validation status will appear here..."
926
+ )
927
+
928
+ # Examples Section
929
+ with gr.Row(elem_classes=["main-container"]):
930
+ gr.Markdown("### πŸ’‘ Example Medical Scenarios")
931
+
932
+ example_queries = [
933
+ "How to treat severe burns when clean water is extremely limited?",
934
+ "Managing gunshot wounds with only basic household supplies",
935
+ "Recognizing and treating infection in wounds without antibiotics",
936
+ "Emergency care for children during extended power outages",
937
+ "Treating compound fractures without proper medical equipment"
938
+ ]
939
+
940
+ gr.Examples(
941
+ examples=example_queries,
942
+ inputs=query_input,
943
+ label="Click any example to try it:",
944
+ examples_per_page=5
945
+ )
946
+
947
+ # Event Handlers
948
+ submit_btn.click(
949
+ process_medical_query_with_progress,
950
+ inputs=query_input,
951
+ outputs=[response_output, metadata_output, status_output],
952
+ show_progress=True
953
+ )
954
+
955
+ query_input.submit(
956
+ process_medical_query_with_progress,
957
+ inputs=query_input,
958
+ outputs=[response_output, metadata_output, status_output],
959
+ show_progress=True
960
+ )
961
+
962
+ clear_btn.click(
963
+ lambda: ("", "", "", ""),
964
+ outputs=[query_input, response_output, metadata_output, status_output]
965
+ )
966
+
967
+ # Refresh stats button
968
+ refresh_stats_btn = gr.Button("πŸ”„ Refresh System Stats", variant="secondary")
969
+ refresh_stats_btn.click(
970
+ lambda: get_system_stats(),
971
+ outputs=stats_display
972
+ )
973
+
974
+ return interface
975
+
976
+ def main():
977
+ """Enhanced main function with optimized system initialization"""
978
+ logger.info("πŸš€ Starting Optimized Gaza First Aid Assistant")
979
+
980
+ try:
981
+ # Check for vector store directory
982
+ vector_store_dir = "./vector_store"
983
+ if not Path(vector_store_dir).exists():
984
+ # Try alternative paths
985
+ alt_paths = ["./results/vector_store", "./results/vector_store_extracted"]
986
+ for alt_path in alt_paths:
987
+ if Path(alt_path).exists():
988
+ vector_store_dir = alt_path
989
+ logger.info(f"πŸ“ Found vector store at: {vector_store_dir}")
990
+ break
991
+ else:
992
+ raise FileNotFoundError("Vector store directory not found. Please ensure pre-made assets are available.")
993
+
994
+ # System initialization with detailed logging
995
+ logger.info(f"πŸ”§ Loading optimized system from: {vector_store_dir}")
996
+ system = initialize_optimized_system(vector_store_dir)
997
+
998
+ # Verify system components
999
+ stats = system.knowledge_base.get_stats()
1000
+ logger.info(f"βœ… Knowledge base loaded: {stats['total_chunks']} chunks, {stats['embedding_dimension']}D")
1001
+ logger.info(f"βœ… Sources: {len(stats['sources'])} documents")
1002
+ logger.info("βœ… Medical fact checker ready")
1003
+ logger.info("βœ… Optimized FAISS indexing active")
1004
+
1005
+ # Create and launch optimized interface
1006
+ logger.info("🎨 Creating optimized Gradio interface...")
1007
+ interface = create_optimized_gradio_interface()
1008
+
1009
+ logger.info("🌐 Launching optimized interface...")
1010
+ interface.launch(
1011
+ server_name="0.0.0.0",
1012
+ server_port=7860,
1013
+ share=False,
1014
+ max_threads=6,
1015
+ show_error=True,
1016
+ quiet=False
1017
+ )
1018
+
1019
+ except Exception as e:
1020
+ logger.error(f"❌ Failed to start Optimized Gaza First Aid Assistant: {e}")
1021
+ print(f"\n🚨 STARTUP ERROR: {e}")
1022
+ print("\nπŸ”§ Troubleshooting Steps:")
1023
+ print("1. Ensure vector_store directory exists with index.faiss, chunks.txt, and metadata.pkl")
1024
+ print("2. Check if all dependencies are installed: pip install -r requirements.txt")
1025
+ print("3. Verify sufficient memory is available (minimum 4GB RAM recommended)")
1026
+ print("4. Check system logs for detailed error information")
1027
+ print("\nπŸ“ž For technical support, check the application logs above.")
1028
+ sys.exit(1)
1029
+
1030
+ if __name__ == "__main__":
1031
+ main()
1032
+