#!/usr/bin/env python3
"""
Comprehensive test script for Enhanced Gaza First Aid RAG Assistant
Tests all major components and validates improvements
"""
import asyncio
import importlib
import logging
import os
import sys
import time
import traceback
from pathlib import Path
# Emit INFO-and-above records with timestamp, logger name and level so that
# output from the components under test can be correlated with test progress.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def test_imports():
    """Check that every third-party dependency of the assistant is importable.

    Each dependency is attempted independently, so a single run reports every
    missing package instead of stopping at the first failure. A module that
    lacks a ``__version__`` attribute is reported with version "unknown"
    rather than raising.

    Returns:
        bool: True when all dependencies imported, False if any of them
        raised ImportError.
    """
    print("π Testing imports...")

    # (display label, module path, attribute that must exist or None,
    #  whether the success line includes the module version)
    required = (
        ("PyTorch", "torch", None, True),
        ("Transformers", "transformers", None, True),
        ("Sentence Transformers", "sentence_transformers", None, True),
        ("FAISS", "faiss", None, True),
        ("Gradio", "gradio", None, True),
        ("LlamaIndex Core", "llama_index.core", "Document", False),
        (
            "LlamaIndex FAISS",
            "llama_index.vector_stores.faiss",
            "FaissVectorStore",
            False,
        ),
        (
            "LlamaIndex HuggingFace Embeddings",
            "llama_index.embeddings.huggingface",
            "HuggingFaceEmbedding",
            False,
        ),
        ("PyPDF2", "PyPDF2", None, True),
    )

    all_available = True
    for label, module_name, attribute, show_version in required:
        try:
            module = importlib.import_module(module_name)
            # Mirror ``from module import attribute``: a missing name is an
            # import failure, not an AttributeError.
            if attribute is not None and not hasattr(module, attribute):
                raise ImportError(
                    f"cannot import name {attribute!r} from {module_name!r}"
                )
        except ImportError as e:
            print(f"β Import error: {e}")
            all_available = False
            continue

        if show_version:
            version = getattr(module, "__version__", "unknown")
            print(f"β {label}: {version}")
        else:
            print(f"β {label}")

    return all_available
def test_data_availability():
    """Report whether ``./data`` holds any medical PDF or text documents.

    Prints the number of ``*.pdf`` and ``*.txt`` files found directly inside
    ``./data`` and previews up to three PDFs with their size.

    Returns:
        bool: True if at least one PDF or text file was found; False when the
        directory does not exist or contains neither kind of file.
    """
    print("\nπ Testing data availability...")

    root = Path("./data")
    if not root.exists():
        print("β Data directory not found")
        return False

    pdfs = list(root.glob("*.pdf"))
    texts = list(root.glob("*.txt"))
    print(f"β Found {len(pdfs)} PDF files")
    print(f"β Found {len(texts)} text files")

    if not (pdfs or texts):
        print("β No medical documents found")
        return False

    # Preview at most the first three PDFs, with sizes converted from bytes.
    bytes_per_mb = 1024 * 1024
    for sample in pdfs[:3]:
        megabytes = sample.stat().st_size / bytes_per_mb
        print(f" π {sample.name} ({megabytes:.1f} MB)")

    return True
def test_embedding_model():
    """Load the all-mpnet-base-v2 embedder on CPU and embed one sample question.

    Prints the embedding length, the time taken, and the first five values.

    Returns:
        tuple: ``(True, embedding_model)`` on success, or ``(False, None)``
        if loading or embedding raised any exception (the traceback is
        printed).
    """
    print("\nπ§ Testing embedding model...")
    try:
        from llama_index.embeddings.huggingface import HuggingFaceEmbedding

        # Higher-dimensional sentence-transformers model, forced onto the CPU.
        print("Loading all-mpnet-base-v2 (768-dim)...")
        model_options = {
            "model_name": "sentence-transformers/all-mpnet-base-v2",
            "device": 'cpu',
            "embed_batch_size": 2,
        }
        embedder = HuggingFaceEmbedding(**model_options)

        # Time a single embedding call.
        sample_question = "How to treat burns with limited water supply?"
        began = time.time()
        vector = embedder.get_text_embedding(sample_question)
        elapsed = time.time() - began

        print(f"β Embedding dimension: {len(vector)}")
        print(f"β Embedding time: {elapsed:.2f}s")
        print(f"β Sample embedding values: {vector[:5]}")
    except Exception as err:
        print(f"β Embedding model error: {err}")
        traceback.print_exc()
        return False, None
    return True, embedder
def test_faiss_indexing():
    """Build a flat and an IVF FAISS index on random vectors and time a search.

    Creates an ``IndexFlatL2`` and an ``IndexIVFFlat`` (10 clusters) of
    dimension 768, trains the IVF index on 50 random float32 vectors, adds the
    same vectors to both indices, and runs one 5-nearest-neighbour query
    against each.

    Returns:
        bool: True if every step completed, False if any exception was raised
        (the traceback is printed).
    """
    print("\nπ Testing FAISS indexing...")
    try:
        import faiss
        import numpy as np

        dimension = 768

        # Exhaustive baseline index.
        flat_index = faiss.IndexFlatL2(dimension)
        print(f"β Created IndexFlatL2 (dimension: {dimension})")

        # Clustered index; nlist is kept small because the sample set is tiny.
        nlist = 10
        quantizer = faiss.IndexFlatL2(dimension)
        ivf_index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
        print(f"β Created IndexIVFFlat (clusters: {nlist})")

        sample_vectors = np.random.random((50, dimension)).astype('float32')

        # The IVF index must be trained before vectors can be added.
        ivf_index.train(sample_vectors)
        print("β IVF index training completed")

        flat_index.add(sample_vectors)
        ivf_index.add(sample_vectors)
        print(f"β Added {len(sample_vectors)} vectors to indices")

        query_vector = np.random.random((1, dimension)).astype('float32')

        # perf_counter is monotonic and high-resolution, which matters for
        # sub-millisecond searches; only the elapsed time is needed here.
        start_time = time.perf_counter()
        flat_index.search(query_vector, 5)
        flat_time = time.perf_counter() - start_time

        start_time = time.perf_counter()
        ivf_index.search(query_vector, 5)
        ivf_time = time.perf_counter() - start_time

        print(f"β Flat search time: {flat_time:.4f}s")
        print(f"β IVF search time: {ivf_time:.4f}s")
        # A search faster than the timer resolution must not fail the test
        # with a ZeroDivisionError.
        if ivf_time > 0:
            print(f"β Speed improvement: {flat_time/ivf_time:.2f}x")
        else:
            print("β Speed improvement: n/a (IVF search below timer resolution)")
        return True
    except Exception as e:
        print(f"β FAISS indexing error: {e}")
        traceback.print_exc()
        return False
def test_knowledge_base():
    """Initialise the enhanced knowledge base and run a few sample searches.

    Imports ``EnhancedGazaKnowledgeBase`` from ``enhanced_gaza_rag_app``
    (after appending the current directory to ``sys.path``), initialises it
    against ``./data``, and issues four queries with ``k=3``, printing the
    timing and the top hit for each.

    Returns:
        tuple: ``(True, knowledge_base)`` on success, or ``(False, None)`` if
        any step raised (the traceback is printed).
    """
    print("\nπ Testing knowledge base...")
    try:
        # Make the application module importable from the working directory.
        sys.path.append('.')
        from enhanced_gaza_rag_app import EnhancedGazaKnowledgeBase

        print("Initializing knowledge base...")
        knowledge_base = EnhancedGazaKnowledgeBase(data_dir="./data")

        began = time.time()
        knowledge_base.initialize()
        init_seconds = time.time() - began
        print(f"β Knowledge base initialized in {init_seconds:.2f}s")
        print(f"β Chunks created: {len(knowledge_base.chunk_metadata)}")

        sample_queries = (
            "How to treat burns?",
            "Managing bleeding wounds",
            "Signs of infection",
            "Emergency care for children",
        )
        for question in sample_queries:
            began = time.time()
            hits = knowledge_base.search(question, k=3)
            search_seconds = time.time() - began
            print(f"β Query: '{question}' -> {len(hits)} results in {search_seconds:.3f}s")

            if not hits:
                continue
            # Only the first returned hit is described.
            top_hit = hits[0]
            print(f" π Best match: {top_hit.get('source', 'unknown')}")
            print(f" π― Score: {top_hit.get('score', 0):.3f}")
            print(f" π₯ Priority: {top_hit.get('medical_priority', 'general')}")

        return True, knowledge_base
    except Exception as err:
        print(f"β Knowledge base error: {err}")
        traceback.print_exc()
        return False, None
def test_llm_loading():
    """Load the Phi-3 mini instruct model in 4-bit and run one short generation.

    Loads the tokenizer and model for ``microsoft/Phi-3-mini-4k-instruct``
    with an NF4 4-bit quantization config, wraps them in a text-generation
    pipeline, and generates up to 50 new tokens for a sample prompt.

    Returns:
        bool: True if a non-empty response was generated; False if the
        response was empty or any step raised (the traceback is printed).
    """
    print("\nπ€ Testing LLM loading...")
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
        import torch

        model_name = "microsoft/Phi-3-mini-4k-instruct"
        print(f"Loading {model_name}...")

        # 4-bit NF4 quantization with double quantization, fp16 compute.
        quant_options = {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": torch.float16,
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4",
        }
        quantization_config = BitsAndBytesConfig(**quant_options)

        load_started = time.time()
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Fall back to the EOS token when the tokenizer defines no pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model_options = {
            "quantization_config": quantization_config,
            "device_map": "auto",
            "trust_remote_code": True,
            "torch_dtype": torch.float16,
            "low_cpu_mem_usage": True,
        }
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_options)
        load_seconds = time.time() - load_started
        print(f"β Model loaded in {load_seconds:.2f}s")

        # return_full_text=False is passed so only the continuation is kept.
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="auto",
            torch_dtype=torch.float16,
            return_full_text=False,
        )
        print("β Generation pipeline created")

        prompt = "How to treat a burn injury: "
        infer_started = time.time()
        outputs = generator(
            prompt,
            max_new_tokens=50,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        infer_seconds = time.time() - infer_started

        if not (outputs and len(outputs) > 0):
            print("β No response generated")
            return False

        completion = outputs[0]['generated_text']
        print(f"β Inference completed in {infer_seconds:.2f}s")
        print(f"β Generated text: {completion[:100]}...")
        return True
    except Exception as err:
        print(f"β LLM loading error: {err}")
        traceback.print_exc()
        return False
def test_full_system():
    """Initialise the complete system and push three medical queries through it.

    Uses ``initialize_enhanced_system`` and
    ``process_medical_query_with_progress`` from ``enhanced_gaza_rag_app``.
    For each query it prints timing, response length, metadata and status,
    and rates the response "Good" when it is longer than 50 characters and
    does not contain the word "error" (case-insensitive).

    Returns:
        bool: True if all queries were processed without raising, otherwise
        False (the traceback is printed). The quality rating does not affect
        the return value.
    """
    print("\nπ Testing complete enhanced system...")
    try:
        from enhanced_gaza_rag_app import initialize_enhanced_system, process_medical_query_with_progress

        print("Initializing complete system...")
        began = time.time()
        _system = initialize_enhanced_system()
        init_seconds = time.time() - began
        print(f"β Complete system initialized in {init_seconds:.2f}s")

        sample_queries = (
            "How to treat severe burns when water is limited?",
            "Managing gunshot wounds with basic supplies",
            "Signs of wound infection to watch for",
        )
        for question in sample_queries:
            print(f"\nπ Testing query: '{question}'")
            began = time.time()
            answer, metadata, status = process_medical_query_with_progress(question)
            elapsed = time.time() - began

            print(f"β Query processed in {elapsed:.2f}s")
            print(f"π Response length: {len(answer)} characters")
            print(f"π Metadata: {metadata}")
            print(f"π‘οΈ Status: {status}")

            # Heuristic quality gate: long enough and no error wording.
            looks_good = len(answer) > 50 and "error" not in answer.lower()
            if looks_good:
                print("β Response quality: Good")
            else:
                print("β οΈ Response quality: Needs improvement")

        return True
    except Exception as err:
        print(f"β Full system test error: {err}")
        traceback.print_exc()
        return False
def test_ui_components():
    """Build the advanced Gradio interface and report how long it took.

    Calls ``create_advanced_gradio_interface`` from
    ``enhanced_ui_gaza_rag_app``; the interface is constructed but not
    launched here.

    Returns:
        bool: True if the interface was created without raising, otherwise
        False (the traceback is printed).
    """
    print("\nπ¨ Testing UI components...")
    try:
        from enhanced_ui_gaza_rag_app import create_advanced_gradio_interface

        print("Creating advanced Gradio interface...")
        began = time.time()
        _interface = create_advanced_gradio_interface()
        build_seconds = time.time() - began
        print(f"β UI created in {build_seconds:.2f}s")

        # These lines are status messages only; nothing is inspected here.
        for message in (
            "β Advanced CSS styling applied",
            "β Progress indicators configured",
            "β Gaza-specific theming applied",
            "β Interactive elements configured",
        ):
            print(message)
        return True
    except Exception as err:
        print(f"β UI components error: {err}")
        traceback.print_exc()
        return False
def run_performance_benchmark():
    """Time ``generate_response`` over a fixed set of benchmark queries.

    Initialises the system via ``initialize_enhanced_system`` from
    ``enhanced_gaza_rag_app`` and runs five queries. A query that raises is
    reported and skipped; it does not abort the benchmark. A summary with the
    average time and success rate is printed when at least one query
    succeeded.

    Returns:
        bool: True if at least one query succeeded; False if every query
        failed or if initialisation raised (the traceback is printed).
    """
    print("\nβ‘ Running performance benchmarks...")
    try:
        from enhanced_gaza_rag_app import initialize_enhanced_system

        system = initialize_enhanced_system()

        benchmark_queries = [
            "How to treat burns?",
            "Managing bleeding wounds",
            "Signs of infection",
            "Emergency care procedures",
            "Trauma management protocols",
        ]

        total_time = 0
        successful_queries = 0
        for number, query in enumerate(benchmark_queries, start=1):
            try:
                # perf_counter is monotonic, so system clock adjustments
                # cannot skew the measurement.
                start_time = time.perf_counter()
                system.generate_response(query)
                query_time = time.perf_counter() - start_time
            except Exception as e:
                print(f"β Query {number} failed: {e}")
                continue
            total_time += query_time
            successful_queries += 1
            print(f"β Query {number}: {query_time:.2f}s")

        # A benchmark in which nothing ran successfully must not be reported
        # as a pass.
        if successful_queries == 0:
            print("β No benchmark queries succeeded")
            return False

        avg_time = total_time / successful_queries
        print(f"\nπ Performance Summary:")
        print(f" Average query time: {avg_time:.2f}s")
        print(f" Successful queries: {successful_queries}/{len(benchmark_queries)}")
        print(f" Success rate: {successful_queries/len(benchmark_queries)*100:.1f}%")
        return True
    except Exception as e:
        print(f"β Performance benchmark error: {e}")
        traceback.print_exc()
        return False
def main():
    """Run every test in order and print a pass/fail summary.

    Each test is called inside its own try block, so an exception in one test
    is recorded as a failure and the remaining tests still run. Tests that
    return a tuple are reduced to their first element (the success flag).

    Returns:
        bool: True only if every test returned a truthy result.
    """
    print("π§ͺ Enhanced Gaza First Aid RAG Assistant - Comprehensive Test Suite")
    print("=" * 70)

    # (display name, zero-argument callable returning a truthy value on success)
    suite = [
        ("Import Dependencies", test_imports),
        ("Data Availability", test_data_availability),
        ("Embedding Model", lambda: test_embedding_model()[0]),
        ("FAISS Indexing", test_faiss_indexing),
        ("Knowledge Base", lambda: test_knowledge_base()[0]),
        ("LLM Loading", test_llm_loading),
        ("Full System", test_full_system),
        ("UI Components", test_ui_components),
        ("Performance Benchmark", run_performance_benchmark),
    ]
    suite_size = len(suite)

    outcomes = {}
    passed_count = 0
    section_rule = '=' * 50
    for label, runner in suite:
        print(f"\n{section_rule}")
        print(f"π§ͺ Running: {label}")
        print(f"{section_rule}")
        try:
            outcome = runner()
            outcomes[label] = outcome
            if not outcome:
                print(f"β {label}: FAILED")
            else:
                passed_count += 1
                print(f"β {label}: PASSED")
        except Exception as err:
            outcomes[label] = False
            print(f"β {label}: ERROR - {err}")

    summary_rule = '=' * 70
    print(f"\n{summary_rule}")
    print("π TEST SUMMARY")
    print(f"{summary_rule}")
    for label, outcome in outcomes.items():
        verdict = "β PASSED" if outcome else "β FAILED"
        print(f"{label:.<40} {verdict}")

    print(f"\nOverall Results: {passed_count}/{suite_size} tests passed")
    print(f"Success Rate: {passed_count/suite_size*100:.1f}%")

    everything_passed = passed_count == suite_size
    if everything_passed:
        print("\nπ ALL TESTS PASSED! Enhanced system is ready for deployment.")
    elif passed_count >= suite_size * 0.8:
        # At least 80% passing is reported as functional with minor issues.
        print("\nβ οΈ Most tests passed. System is functional with minor issues.")
    else:
        print("\nπ¨ Multiple test failures. System needs attention before deployment.")
    return everything_passed
# Script entry point: exit status 0 when the whole suite passed, 1 otherwise.
if __name__ == "__main__":
    sys.exit(0 if main() else 1)