HNTAI / test_refactored_system.py
sachinchandrankallar's picture
optimized code
c6f267d
raw
history blame
12.8 kB
#!/usr/bin/env python3
"""
Test script for the refactored HNTAI system
Demonstrates the new unified model manager and dynamic model loading capabilities
"""
import os
import sys
import time
import logging
import requests
import json
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Set environment variables for testing
os.environ['HF_HOME'] = '/tmp/huggingface'
os.environ['GGUF_N_THREADS'] = '2'
os.environ['GGUF_N_BATCH'] = '64'
def test_model_manager():
"""Test the unified model manager"""
logger.info("Testing Unified Model Manager...")
try:
from ai_med_extract.utils.model_manager import model_manager
# Test 1: Load a transformers model
logger.info("Test 1: Loading Transformers model...")
loader = model_manager.get_model_loader("facebook/bart-base", "text-generation")
result = loader.generate("Hello, how are you?", max_new_tokens=50)
logger.info(f"βœ… Transformers model test passed: {len(result)} characters generated")
# Test 2: Load a GGUF model
logger.info("Test 2: Loading GGUF model...")
try:
gguf_loader = model_manager.get_model_loader(
"microsoft/Phi-3-mini-4k-instruct-gguf",
"gguf"
)
result = gguf_loader.generate("Generate a brief medical summary: Patient has fever and cough.", max_tokens=100)
logger.info(f"βœ… GGUF model test passed: {len(result)} characters generated")
except Exception as e:
logger.warning(f"⚠️ GGUF model test failed (this is expected if model not available): {e}")
# Test 3: Test fallback mechanism
logger.info("Test 3: Testing fallback mechanism...")
try:
fallback_loader = model_manager.get_model_loader("invalid/model", "text-generation")
result = fallback_loader.generate("Test prompt")
logger.info(f"βœ… Fallback mechanism test passed: {len(result)} characters generated")
except Exception as e:
logger.error(f"❌ Fallback mechanism test failed: {e}")
return False
logger.info("πŸŽ‰ All model manager tests passed!")
return True
except Exception as e:
logger.error(f"❌ Model manager test failed: {e}")
return False
def test_patient_summarizer():
"""Test the refactored patient summarizer agent"""
logger.info("Testing Patient Summarizer Agent...")
try:
from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
# Test with different model types
test_cases = [
{
"name": "Transformers Summarization",
"model_name": "Falconsai/medical_summarization",
"model_type": "summarization"
},
{
"name": "GGUF Model",
"model_name": "microsoft/Phi-3-mini-4k-instruct-gguf",
"model_type": "gguf"
}
]
for test_case in test_cases:
logger.info(f"Testing: {test_case['name']}")
try:
agent = PatientSummarizerAgent(
model_name=test_case["model_name"],
model_type=test_case["model_type"]
)
# Test with sample patient data
sample_data = {
"result": {
"patientname": "John Doe",
"patientnumber": "12345",
"agey": "45",
"gender": "Male",
"allergies": ["Penicillin"],
"social_history": "Non-smoker, occasional alcohol",
"past_medical_history": ["Hypertension", "Diabetes"],
"encounters": [
{
"visit_date": "2024-01-15",
"chief_complaint": "Chest pain",
"symptoms": "Sharp chest pain, shortness of breath",
"diagnosis": ["Angina", "Hypertension"],
"dr_notes": "Patient reports chest pain for 2 days",
"vitals": {"BP": "140/90", "HR": "85", "SpO2": "98%"},
"medications": ["Aspirin", "Metoprolol"],
"treatment": "Prescribed medications, follow-up in 1 week"
}
]
}
}
summary = agent.generate_clinical_summary(sample_data)
logger.info(f"βœ… {test_case['name']} test passed: {len(summary)} characters generated")
except Exception as e:
logger.warning(f"⚠️ {test_case['name']} test failed (this may be expected): {e}")
logger.info("πŸŽ‰ Patient summarizer tests completed!")
return True
except Exception as e:
logger.error(f"❌ Patient summarizer test failed: {e}")
return False
def test_model_config():
"""Test the model configuration system"""
logger.info("Testing Model Configuration...")
try:
from ai_med_extract.utils.model_config import (
detect_model_type,
validate_model_config,
get_model_info,
get_default_model
)
# Test model type detection
test_models = [
("facebook/bart-base", "text-generation"),
("Falconsai/medical_summarization", "summarization"),
("microsoft/Phi-3-mini-4k-instruct-gguf", "gguf"),
("model.gguf", "gguf"),
("unknown/model", "text-generation") # Default fallback
]
for model_name, expected_type in test_models:
detected_type = detect_model_type(model_name)
if detected_type == expected_type:
logger.info(f"βœ… Model type detection correct: {model_name} -> {detected_type}")
else:
logger.warning(f"⚠️ Model type detection mismatch: {model_name} -> {detected_type} (expected {expected_type})")
# Test model validation
validation = validate_model_config("microsoft/Phi-3-mini-4k-instruct-gguf", "gguf")
if validation["valid"]:
logger.info("βœ… Model validation test passed")
else:
logger.warning(f"⚠️ Model validation warnings: {validation['warnings']}")
# Test default models
default_summary = get_default_model("summarization")
logger.info(f"βœ… Default summarization model: {default_summary}")
logger.info("πŸŽ‰ Model configuration tests completed!")
return True
except Exception as e:
logger.error(f"❌ Model configuration test failed: {e}")
return False
def test_api_endpoints():
"""Test the new API endpoints (if server is running)"""
logger.info("Testing API Endpoints...")
base_url = "http://localhost:7860" # Adjust if different
try:
# Test health check
response = requests.get(f"{base_url}/api/models/health", timeout=10)
if response.status_code == 200:
health_data = response.json()
logger.info(f"βœ… Health check passed: {health_data.get('status', 'unknown')}")
logger.info(f" Loaded models: {health_data.get('loaded_models_count', 0)}")
if health_data.get('gpu_info', {}).get('available'):
logger.info(f" GPU memory: {health_data['gpu_info']['memory_allocated']}")
else:
logger.warning(f"⚠️ Health check failed with status {response.status_code}")
return False
# Test model info
response = requests.get(f"{base_url}/api/models/info", timeout=10)
if response.status_code == 200:
info_data = response.json()
logger.info(f"βœ… Model info endpoint working: {info_data.get('total_models', 0)} models loaded")
else:
logger.warning(f"⚠️ Model info endpoint failed with status {response.status_code}")
# Test default models
response = requests.get(f"{base_url}/api/models/defaults", timeout=10)
if response.status_code == 200:
defaults_data = response.json()
logger.info(f"βœ… Default models endpoint working: {len(defaults_data.get('default_models', {}))} model types available")
else:
logger.warning(f"⚠️ Default models endpoint failed with status {response.status_code}")
logger.info("πŸŽ‰ API endpoint tests completed!")
return True
except requests.exceptions.ConnectionError:
logger.warning("⚠️ Server not running, skipping API tests")
return True
except Exception as e:
logger.error(f"❌ API endpoint test failed: {e}")
return False
def test_memory_optimization():
"""Test memory optimization features"""
logger.info("Testing Memory Optimization...")
try:
import torch
# Check if we're in Hugging Face Spaces
is_hf_space = os.environ.get('SPACE_ID') is not None
if is_hf_space:
logger.info("πŸ”„ Detected Hugging Face Space - testing memory optimization...")
# Test with smaller models
from ai_med_extract.utils.model_manager import model_manager
loader = model_manager.get_model_loader("facebook/bart-base", "text-generation")
result = loader.generate("Test prompt for memory optimization", max_new_tokens=50)
logger.info(f"βœ… Memory optimization test passed: {len(result)} characters generated")
else:
logger.info("πŸ”„ Local environment detected - memory optimization not applicable")
# Test cache clearing
from ai_med_extract.utils.model_manager import model_manager
model_manager.clear_cache()
logger.info("βœ… Cache clearing test passed")
return True
except Exception as e:
logger.error(f"❌ Memory optimization test failed: {e}")
return False
def main():
"""Main test function"""
logger.info("πŸš€ Starting HNTAI Refactored System Tests...")
logger.info("=" * 60)
test_results = []
# Run all tests
tests = [
("Model Manager", test_model_manager),
("Patient Summarizer", test_patient_summarizer),
("Model Configuration", test_model_config),
("API Endpoints", test_api_endpoints),
("Memory Optimization", test_memory_optimization)
]
for test_name, test_func in tests:
logger.info(f"\nπŸ§ͺ Running {test_name} Test...")
try:
result = test_func()
test_results.append((test_name, result))
except Exception as e:
logger.error(f"❌ {test_name} test crashed: {e}")
test_results.append((test_name, False))
# Summary
logger.info("\n" + "=" * 60)
logger.info("πŸ“Š TEST SUMMARY")
logger.info("=" * 60)
passed = 0
total = len(test_results)
for test_name, result in test_results:
status = "βœ… PASS" if result else "❌ FAIL"
logger.info(f"{test_name}: {status}")
if result:
passed += 1
logger.info(f"\nOverall: {passed}/{total} tests passed")
if passed == total:
logger.info("πŸŽ‰ All tests passed! The refactored system is working correctly.")
logger.info("✨ You can now use any model name and type, including GGUF models!")
else:
logger.warning(f"⚠️ {total - passed} tests failed. Check the logs above for details.")
# Recommendations
logger.info("\nπŸ’‘ RECOMMENDATIONS:")
if passed >= total * 0.8:
logger.info("βœ… System is ready for production use")
logger.info("βœ… GGUF models are supported for patient summaries")
logger.info("βœ… Dynamic model loading is working")
elif passed >= total * 0.6:
logger.info("⚠️ System is mostly working but has some issues")
logger.info("⚠️ Check failed tests and fix issues")
else:
logger.error("❌ System has significant issues")
logger.error("❌ Review and fix failed tests before use")
return passed == total
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)