rivapereira123 commited on
Commit
5d652f8
Β·
verified Β·
1 Parent(s): b99e661

Upload test_enhanced_system.py

Browse files
Files changed (1) hide show
  1. test_enhanced_system.py +490 -0
test_enhanced_system.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive test script for Enhanced Gaza First Aid RAG Assistant
4
+ Tests all major components and validates improvements
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import time
10
+ import logging
11
+ import traceback
12
+ from pathlib import Path
13
+ import asyncio
14
+
15
+ # Configure logging for testing
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
19
+ )
20
+ logger = logging.getLogger(__name__)
21
+
22
+ def test_imports():
23
+ """Test all required imports"""
24
+ print("πŸ” Testing imports...")
25
+
26
+ try:
27
+ import torch
28
+ print(f"βœ… PyTorch: {torch.__version__}")
29
+
30
+ import transformers
31
+ print(f"βœ… Transformers: {transformers.__version__}")
32
+
33
+ import sentence_transformers
34
+ print(f"βœ… Sentence Transformers: {sentence_transformers.__version__}")
35
+
36
+ import faiss
37
+ print(f"βœ… FAISS: {faiss.__version__}")
38
+
39
+ import gradio as gr
40
+ print(f"βœ… Gradio: {gr.__version__}")
41
+
42
+ from llama_index.core import Document
43
+ print("βœ… LlamaIndex Core")
44
+
45
+ from llama_index.vector_stores.faiss import FaissVectorStore
46
+ print("βœ… LlamaIndex FAISS")
47
+
48
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
49
+ print("βœ… LlamaIndex HuggingFace Embeddings")
50
+
51
+ import PyPDF2
52
+ print(f"βœ… PyPDF2: {PyPDF2.__version__}")
53
+
54
+ return True
55
+
56
+ except ImportError as e:
57
+ print(f"❌ Import error: {e}")
58
+ return False
59
+
60
+ def test_data_availability():
61
+ """Test if medical data is available"""
62
+ print("\nπŸ“ Testing data availability...")
63
+
64
+ data_dir = Path("./data")
65
+ if not data_dir.exists():
66
+ print("❌ Data directory not found")
67
+ return False
68
+
69
+ pdf_files = list(data_dir.glob("*.pdf"))
70
+ txt_files = list(data_dir.glob("*.txt"))
71
+
72
+ print(f"βœ… Found {len(pdf_files)} PDF files")
73
+ print(f"βœ… Found {len(txt_files)} text files")
74
+
75
+ if len(pdf_files) == 0 and len(txt_files) == 0:
76
+ print("❌ No medical documents found")
77
+ return False
78
+
79
+ # Show sample files
80
+ for i, pdf_file in enumerate(pdf_files[:3]):
81
+ size_mb = pdf_file.stat().st_size / (1024 * 1024)
82
+ print(f" πŸ“„ {pdf_file.name} ({size_mb:.1f} MB)")
83
+
84
+ return True
85
+
86
+ def test_embedding_model():
87
+ """Test embedding model loading and functionality"""
88
+ print("\n🧠 Testing embedding model...")
89
+
90
+ try:
91
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
92
+
93
+ # Test higher-dimensional model
94
+ print("Loading all-mpnet-base-v2 (768-dim)...")
95
+ embedding_model = HuggingFaceEmbedding(
96
+ model_name="sentence-transformers/all-mpnet-base-v2",
97
+ device='cpu',
98
+ embed_batch_size=2
99
+ )
100
+
101
+ # Test embedding generation
102
+ test_text = "How to treat burns with limited water supply?"
103
+ start_time = time.time()
104
+ embedding = embedding_model.get_text_embedding(test_text)
105
+ embedding_time = time.time() - start_time
106
+
107
+ print(f"βœ… Embedding dimension: {len(embedding)}")
108
+ print(f"βœ… Embedding time: {embedding_time:.2f}s")
109
+ print(f"βœ… Sample embedding values: {embedding[:5]}")
110
+
111
+ return True, embedding_model
112
+
113
+ except Exception as e:
114
+ print(f"❌ Embedding model error: {e}")
115
+ traceback.print_exc()
116
+ return False, None
117
+
118
+ def test_faiss_indexing():
119
+ """Test FAISS indexing functionality"""
120
+ print("\nπŸ” Testing FAISS indexing...")
121
+
122
+ try:
123
+ import faiss
124
+ import numpy as np
125
+
126
+ # Test different index types
127
+ dimension = 768
128
+
129
+ # Test flat index
130
+ flat_index = faiss.IndexFlatL2(dimension)
131
+ print(f"βœ… Created IndexFlatL2 (dimension: {dimension})")
132
+
133
+ # Test IVF index
134
+ nlist = 10 # Small for testing
135
+ quantizer = faiss.IndexFlatL2(dimension)
136
+ ivf_index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
137
+ print(f"βœ… Created IndexIVFFlat (clusters: {nlist})")
138
+
139
+ # Test with sample data
140
+ sample_vectors = np.random.random((50, dimension)).astype('float32')
141
+
142
+ # Train IVF index
143
+ ivf_index.train(sample_vectors)
144
+ print("βœ… IVF index training completed")
145
+
146
+ # Add vectors
147
+ flat_index.add(sample_vectors)
148
+ ivf_index.add(sample_vectors)
149
+ print(f"βœ… Added {len(sample_vectors)} vectors to indices")
150
+
151
+ # Test search
152
+ query_vector = np.random.random((1, dimension)).astype('float32')
153
+
154
+ start_time = time.time()
155
+ flat_distances, flat_indices = flat_index.search(query_vector, 5)
156
+ flat_time = time.time() - start_time
157
+
158
+ start_time = time.time()
159
+ ivf_distances, ivf_indices = ivf_index.search(query_vector, 5)
160
+ ivf_time = time.time() - start_time
161
+
162
+ print(f"βœ… Flat search time: {flat_time:.4f}s")
163
+ print(f"βœ… IVF search time: {ivf_time:.4f}s")
164
+ print(f"βœ… Speed improvement: {flat_time/ivf_time:.2f}x")
165
+
166
+ return True
167
+
168
+ except Exception as e:
169
+ print(f"❌ FAISS indexing error: {e}")
170
+ traceback.print_exc()
171
+ return False
172
+
173
+ def test_knowledge_base():
174
+ """Test knowledge base initialization and search"""
175
+ print("\nπŸ“š Testing knowledge base...")
176
+
177
+ try:
178
+ # Import the enhanced system
179
+ sys.path.append('.')
180
+ from enhanced_gaza_rag_app import EnhancedGazaKnowledgeBase
181
+
182
+ # Initialize knowledge base
183
+ print("Initializing knowledge base...")
184
+ kb = EnhancedGazaKnowledgeBase(data_dir="./data")
185
+
186
+ start_time = time.time()
187
+ kb.initialize()
188
+ init_time = time.time() - start_time
189
+
190
+ print(f"βœ… Knowledge base initialized in {init_time:.2f}s")
191
+ print(f"βœ… Chunks created: {len(kb.chunk_metadata)}")
192
+
193
+ # Test search functionality
194
+ test_queries = [
195
+ "How to treat burns?",
196
+ "Managing bleeding wounds",
197
+ "Signs of infection",
198
+ "Emergency care for children"
199
+ ]
200
+
201
+ for query in test_queries:
202
+ start_time = time.time()
203
+ results = kb.search(query, k=3)
204
+ search_time = time.time() - start_time
205
+
206
+ print(f"βœ… Query: '{query}' -> {len(results)} results in {search_time:.3f}s")
207
+
208
+ if results:
209
+ best_result = results[0]
210
+ print(f" πŸ“„ Best match: {best_result.get('source', 'unknown')}")
211
+ print(f" 🎯 Score: {best_result.get('score', 0):.3f}")
212
+ print(f" πŸ₯ Priority: {best_result.get('medical_priority', 'general')}")
213
+
214
+ return True, kb
215
+
216
+ except Exception as e:
217
+ print(f"❌ Knowledge base error: {e}")
218
+ traceback.print_exc()
219
+ return False, None
220
+
221
+ def test_llm_loading():
222
+ """Test LLM loading and inference"""
223
+ print("\nπŸ€– Testing LLM loading...")
224
+
225
+ try:
226
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
227
+ import torch
228
+
229
+ model_name = "microsoft/Phi-3-mini-4k-instruct"
230
+ print(f"Loading {model_name}...")
231
+
232
+ # Test quantization config
233
+ quantization_config = BitsAndBytesConfig(
234
+ load_in_4bit=True,
235
+ bnb_4bit_compute_dtype=torch.float16,
236
+ bnb_4bit_use_double_quant=True,
237
+ bnb_4bit_quant_type="nf4"
238
+ )
239
+
240
+ start_time = time.time()
241
+
242
+ tokenizer = AutoTokenizer.from_pretrained(
243
+ model_name,
244
+ trust_remote_code=True
245
+ )
246
+
247
+ if tokenizer.pad_token is None:
248
+ tokenizer.pad_token = tokenizer.eos_token
249
+
250
+ model = AutoModelForCausalLM.from_pretrained(
251
+ model_name,
252
+ quantization_config=quantization_config,
253
+ device_map="auto",
254
+ trust_remote_code=True,
255
+ torch_dtype=torch.float16,
256
+ low_cpu_mem_usage=True
257
+ )
258
+
259
+ loading_time = time.time() - start_time
260
+ print(f"βœ… Model loaded in {loading_time:.2f}s")
261
+
262
+ # Test pipeline creation
263
+ generation_pipeline = pipeline(
264
+ "text-generation",
265
+ model=model,
266
+ tokenizer=tokenizer,
267
+ device_map="auto",
268
+ torch_dtype=torch.float16,
269
+ return_full_text=False
270
+ )
271
+
272
+ print("βœ… Generation pipeline created")
273
+
274
+ # Test inference
275
+ test_prompt = "How to treat a burn injury: "
276
+ start_time = time.time()
277
+
278
+ response = generation_pipeline(
279
+ test_prompt,
280
+ max_new_tokens=50,
281
+ temperature=0.2,
282
+ do_sample=True,
283
+ pad_token_id=tokenizer.eos_token_id
284
+ )
285
+
286
+ inference_time = time.time() - start_time
287
+
288
+ if response and len(response) > 0:
289
+ generated_text = response[0]['generated_text']
290
+ print(f"βœ… Inference completed in {inference_time:.2f}s")
291
+ print(f"βœ… Generated text: {generated_text[:100]}...")
292
+ else:
293
+ print("❌ No response generated")
294
+ return False
295
+
296
+ return True
297
+
298
+ except Exception as e:
299
+ print(f"❌ LLM loading error: {e}")
300
+ traceback.print_exc()
301
+ return False
302
+
303
+ def test_full_system():
304
+ """Test the complete enhanced system"""
305
+ print("\nπŸš€ Testing complete enhanced system...")
306
+
307
+ try:
308
+ # Import the enhanced system
309
+ from enhanced_gaza_rag_app import initialize_enhanced_system, process_medical_query_with_progress
310
+
311
+ print("Initializing complete system...")
312
+ start_time = time.time()
313
+ system = initialize_enhanced_system()
314
+ init_time = time.time() - start_time
315
+
316
+ print(f"βœ… Complete system initialized in {init_time:.2f}s")
317
+
318
+ # Test queries
319
+ test_queries = [
320
+ "How to treat severe burns when water is limited?",
321
+ "Managing gunshot wounds with basic supplies",
322
+ "Signs of wound infection to watch for"
323
+ ]
324
+
325
+ for query in test_queries:
326
+ print(f"\nπŸ” Testing query: '{query}'")
327
+
328
+ start_time = time.time()
329
+ response, metadata, status = process_medical_query_with_progress(query)
330
+ query_time = time.time() - start_time
331
+
332
+ print(f"βœ… Query processed in {query_time:.2f}s")
333
+ print(f"πŸ“ Response length: {len(response)} characters")
334
+ print(f"πŸ“Š Metadata: {metadata}")
335
+ print(f"πŸ›‘οΈ Status: {status}")
336
+
337
+ # Check response quality
338
+ if len(response) > 50 and "error" not in response.lower():
339
+ print("βœ… Response quality: Good")
340
+ else:
341
+ print("⚠️ Response quality: Needs improvement")
342
+
343
+ return True
344
+
345
+ except Exception as e:
346
+ print(f"❌ Full system test error: {e}")
347
+ traceback.print_exc()
348
+ return False
349
+
350
+ def test_ui_components():
351
+ """Test UI components and interface"""
352
+ print("\n🎨 Testing UI components...")
353
+
354
+ try:
355
+ from enhanced_ui_gaza_rag_app import create_advanced_gradio_interface
356
+
357
+ print("Creating advanced Gradio interface...")
358
+ start_time = time.time()
359
+ interface = create_advanced_gradio_interface()
360
+ ui_time = time.time() - start_time
361
+
362
+ print(f"βœ… UI created in {ui_time:.2f}s")
363
+ print("βœ… Advanced CSS styling applied")
364
+ print("βœ… Progress indicators configured")
365
+ print("βœ… Gaza-specific theming applied")
366
+ print("βœ… Interactive elements configured")
367
+
368
+ return True
369
+
370
+ except Exception as e:
371
+ print(f"❌ UI components error: {e}")
372
+ traceback.print_exc()
373
+ return False
374
+
375
+ def run_performance_benchmark():
376
+ """Run performance benchmarks"""
377
+ print("\n⚑ Running performance benchmarks...")
378
+
379
+ try:
380
+ from enhanced_gaza_rag_app import initialize_enhanced_system
381
+
382
+ system = initialize_enhanced_system()
383
+
384
+ # Benchmark queries
385
+ benchmark_queries = [
386
+ "How to treat burns?",
387
+ "Managing bleeding wounds",
388
+ "Signs of infection",
389
+ "Emergency care procedures",
390
+ "Trauma management protocols"
391
+ ]
392
+
393
+ total_time = 0
394
+ successful_queries = 0
395
+
396
+ for i, query in enumerate(benchmark_queries):
397
+ try:
398
+ start_time = time.time()
399
+ result = system.generate_response(query)
400
+ query_time = time.time() - start_time
401
+
402
+ total_time += query_time
403
+ successful_queries += 1
404
+
405
+ print(f"βœ… Query {i+1}: {query_time:.2f}s")
406
+
407
+ except Exception as e:
408
+ print(f"❌ Query {i+1} failed: {e}")
409
+
410
+ if successful_queries > 0:
411
+ avg_time = total_time / successful_queries
412
+ print(f"\nπŸ“Š Performance Summary:")
413
+ print(f" Average query time: {avg_time:.2f}s")
414
+ print(f" Successful queries: {successful_queries}/{len(benchmark_queries)}")
415
+ print(f" Success rate: {successful_queries/len(benchmark_queries)*100:.1f}%")
416
+
417
+ return True
418
+
419
+ except Exception as e:
420
+ print(f"❌ Performance benchmark error: {e}")
421
+ traceback.print_exc()
422
+ return False
423
+
424
+ def main():
425
+ """Run comprehensive test suite"""
426
+ print("πŸ§ͺ Enhanced Gaza First Aid RAG Assistant - Comprehensive Test Suite")
427
+ print("=" * 70)
428
+
429
+ test_results = {}
430
+
431
+ # Run all tests
432
+ tests = [
433
+ ("Import Dependencies", test_imports),
434
+ ("Data Availability", test_data_availability),
435
+ ("Embedding Model", lambda: test_embedding_model()[0]),
436
+ ("FAISS Indexing", test_faiss_indexing),
437
+ ("Knowledge Base", lambda: test_knowledge_base()[0]),
438
+ ("LLM Loading", test_llm_loading),
439
+ ("Full System", test_full_system),
440
+ ("UI Components", test_ui_components),
441
+ ("Performance Benchmark", run_performance_benchmark)
442
+ ]
443
+
444
+ passed_tests = 0
445
+ total_tests = len(tests)
446
+
447
+ for test_name, test_func in tests:
448
+ print(f"\n{'='*50}")
449
+ print(f"πŸ§ͺ Running: {test_name}")
450
+ print(f"{'='*50}")
451
+
452
+ try:
453
+ result = test_func()
454
+ test_results[test_name] = result
455
+
456
+ if result:
457
+ passed_tests += 1
458
+ print(f"βœ… {test_name}: PASSED")
459
+ else:
460
+ print(f"❌ {test_name}: FAILED")
461
+
462
+ except Exception as e:
463
+ test_results[test_name] = False
464
+ print(f"❌ {test_name}: ERROR - {e}")
465
+
466
+ # Final summary
467
+ print(f"\n{'='*70}")
468
+ print("🏁 TEST SUMMARY")
469
+ print(f"{'='*70}")
470
+
471
+ for test_name, result in test_results.items():
472
+ status = "βœ… PASSED" if result else "❌ FAILED"
473
+ print(f"{test_name:.<40} {status}")
474
+
475
+ print(f"\nOverall Results: {passed_tests}/{total_tests} tests passed")
476
+ print(f"Success Rate: {passed_tests/total_tests*100:.1f}%")
477
+
478
+ if passed_tests == total_tests:
479
+ print("\nπŸŽ‰ ALL TESTS PASSED! Enhanced system is ready for deployment.")
480
+ elif passed_tests >= total_tests * 0.8:
481
+ print("\n⚠️ Most tests passed. System is functional with minor issues.")
482
+ else:
483
+ print("\n🚨 Multiple test failures. System needs attention before deployment.")
484
+
485
+ return passed_tests == total_tests
486
+
487
+ if __name__ == "__main__":
488
+ success = main()
489
+ sys.exit(0 if success else 1)
490
+