medical / test.py
Dama03's picture
gg
feb5633
#!/usr/bin/env python3
"""
FastAPI Medical AI - Testing Script
Test all endpoints for backend integration
"""
import requests
import json
import time
import os
from typing import Dict, Any
# Configuration
API_BASE_URL = "http://localhost:8000"  # Change this to your deployed URL
TEST_AUDIO_PATH = "test_audio.wav"  # Optional audio file for testing
REQUEST_TIMEOUT = 120.0  # Seconds to wait for each response before giving up


class MedicalAITester:
    """Smoke-test client for the FastAPI Medical AI backend.

    Every ``test_*`` method sends one HTTP request (three for the
    documentation check) to the API at ``base_url``, judges the outcome by
    HTTP status code and records it through :meth:`log_result`. Recorded
    outcomes accumulate in ``self.results`` as dicts with the keys
    ``'test'``, ``'success'`` and ``'details'``.
    """

    def __init__(self, base_url: str = API_BASE_URL, timeout: float = REQUEST_TIMEOUT):
        """Create a tester bound to one API deployment.

        Args:
            base_url: Root URL of the API; a trailing slash is removed so
                paths can be appended directly.
            timeout: Seconds to wait for each response. Pass ``None`` to
                wait indefinitely (the behaviour before this parameter
                existed).
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.session = requests.Session()
        self.results = []

    def _get(self, path: str, **kwargs: Any):
        """Send a GET to ``base_url + path`` using the configured timeout."""
        return self.session.get(f"{self.base_url}{path}", timeout=self.timeout, **kwargs)

    def _post(self, path: str, **kwargs: Any):
        """Send a POST to ``base_url + path`` using the configured timeout."""
        return self.session.post(f"{self.base_url}{path}", timeout=self.timeout, **kwargs)

    def log_result(self, test_name: str, success: bool, details: Dict[str, Any]):
        """Print a PASS/FAIL line for one test and store it in ``self.results``.

        Args:
            test_name: Human-readable name shown in the output.
            success: Whether the test met its expectation.
            details: Extra information to print and keep; may be empty.
        """
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"{status} {test_name}")
        if details:
            print(f" Details: {details}")
        self.results.append({
            'test': test_name,
            'success': success,
            'details': details
        })

    def test_root_endpoint(self):
        """Check that ``GET /`` answers with HTTP 200."""
        try:
            response = self._get("/")
            success = response.status_code == 200
            details = response.json() if success else {"error": response.text}
            self.log_result("Root Endpoint", success, details)
        except Exception as e:
            # Connection errors, timeouts and invalid JSON all count as a failed test.
            self.log_result("Root Endpoint", False, {"error": str(e)})

    def test_health_check(self):
        """Check that ``GET /health`` answers with HTTP 200 and log its status fields."""
        try:
            response = self._get("/health")
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "status": data.get("status"),
                    "models_loaded": data.get("models_loaded"),
                    "audio_available": data.get("audio_available")
                }
            else:
                details = {"error": response.text}
            self.log_result("Health Check", success, details)
        except Exception as e:
            self.log_result("Health Check", False, {"error": str(e)})

    def test_medical_consultation(self):
        """Post an English question to ``/medical/ask``.

        Returns:
            The ``conversation_id`` from the response on HTTP 200, otherwise
            ``None`` (also when the request itself fails).
        """
        try:
            payload = {
                "question": "What are the symptoms of malaria and how is it treated?",
                "language": "en"
            }
            response = self._post(
                "/medical/ask",
                json=payload,
                headers={"Content-Type": "application/json"}
            )
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    # ``or`` guards against JSON null, which len() cannot handle.
                    "response_length": len(data.get("response") or ""),
                    "detected_language": data.get("detected_language"),
                    "processing_time": data.get("processing_time"),
                    "conversation_id": data.get("conversation_id"),
                    "contexts_used": len(data.get("context_used") or [])
                }
            else:
                details = {"error": response.text, "status_code": response.status_code}
            self.log_result("Medical Consultation (Text)", success, details)
            # Return conversation_id for feedback test
            return data.get("conversation_id") if success else None
        except Exception as e:
            self.log_result("Medical Consultation (Text)", False, {"error": str(e)})
            return None

    def test_medical_consultation_french(self):
        """Post a French question to ``/medical/ask`` and expect HTTP 200."""
        try:
            payload = {
                "question": "Quels sont les symptômes du diabète?",
                "language": "fr"
            }
            response = self._post("/medical/ask", json=payload)
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "response_length": len(data.get("response") or ""),
                    "detected_language": data.get("detected_language"),
                    "processing_time": data.get("processing_time")
                }
            else:
                details = {"error": response.text}
            self.log_result("Medical Consultation (French)", success, details)
        except Exception as e:
            self.log_result("Medical Consultation (French)", False, {"error": str(e)})

    def test_audio_consultation(self):
        """Upload ``TEST_AUDIO_PATH`` to ``/medical/audio`` and expect HTTP 200.

        When the audio file does not exist, the test is logged as failed
        without contacting the API.
        """
        if not os.path.exists(TEST_AUDIO_PATH):
            self.log_result("Audio Consultation", False, {"error": "No test audio file found"})
            return
        try:
            with open(TEST_AUDIO_PATH, 'rb') as audio_file:
                files = {'file': ('test_audio.wav', audio_file, 'audio/wav')}
                # The upload must happen while the file handle is still open.
                response = self._post("/medical/audio", files=files)
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    # Only a 50-character preview of the transcription is logged.
                    "transcription": (data.get("transcription") or "")[:50] + "...",
                    "response_length": len(data.get("response") or ""),
                    "detected_language": data.get("detected_language"),
                    "processing_time": data.get("processing_time"),
                    "audio_duration": data.get("audio_duration")
                }
            else:
                details = {"error": response.text}
            self.log_result("Audio Consultation", success, details)
        except Exception as e:
            self.log_result("Audio Consultation", False, {"error": str(e)})

    def test_feedback_submission(self, conversation_id: str = None):
        """Post a five-star rating to ``/feedback`` and expect HTTP 200.

        Args:
            conversation_id: Conversation to rate; a placeholder id is used
                when none is supplied.
        """
        if not conversation_id:
            conversation_id = "test_conv_123"
        try:
            payload = {
                "conversation_id": conversation_id,
                "rating": 5,
                "feedback": "Great medical advice, very helpful and accurate"
            }
            response = self._post("/feedback", json=payload)
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "message": data.get("message"),
                    "feedback_id": data.get("feedback_id")
                }
            else:
                details = {"error": response.text}
            self.log_result("Feedback Submission", success, details)
        except Exception as e:
            self.log_result("Feedback Submission", False, {"error": str(e)})

    def test_medical_specialties(self):
        """Check that ``GET /medical/specialties`` answers with HTTP 200."""
        try:
            response = self._get("/medical/specialties")
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "specialties_count": len(data.get("specialties") or []),
                    "languages_supported": data.get("languages_supported", [])
                }
            else:
                details = {"error": response.text}
            self.log_result("Medical Specialties", success, details)
        except Exception as e:
            self.log_result("Medical Specialties", False, {"error": str(e)})

    def test_invalid_endpoints(self):
        """Check that an unknown path is answered with HTTP 404."""
        try:
            response = self._get("/invalid/endpoint")
            success = response.status_code == 404
            details = {"status_code": response.status_code}
            self.log_result("404 Error Handling", success, details)
        except Exception as e:
            self.log_result("404 Error Handling", False, {"error": str(e)})

    def test_validation_errors(self):
        """Check that an invalid ``/medical/ask`` payload is rejected with HTTP 422."""
        try:
            # Send invalid data to medical consultation
            payload = {
                "question": "",  # Empty question should fail validation
                "language": "invalid_lang"  # Invalid language
            }
            response = self._post("/medical/ask", json=payload)
            success = response.status_code == 422  # Validation error
            details = {"status_code": response.status_code}
            self.log_result("Validation Error Handling", success, details)
        except Exception as e:
            self.log_result("Validation Error Handling", False, {"error": str(e)})

    def test_openapi_docs(self):
        """Check that ``/docs``, ``/redoc`` and ``/openapi.json`` each answer with HTTP 200.

        Each endpoint is logged as its own test result.
        """
        endpoints_to_test = [
            ("/docs", "Swagger UI"),
            ("/redoc", "ReDoc UI"),
            ("/openapi.json", "OpenAPI Schema")
        ]
        for endpoint, name in endpoints_to_test:
            try:
                response = self._get(endpoint)
                success = response.status_code == 200
                details = {"status_code": response.status_code}
                self.log_result(f"Documentation - {name}", success, details)
            except Exception as e:
                self.log_result(f"Documentation - {name}", False, {"error": str(e)})

    def run_all_tests(self):
        """Run every test in a fixed order, then print the summary."""
        print("🧪 Starting FastAPI Medical AI Tests")
        print("=" * 50)
        # Basic connectivity tests
        self.test_root_endpoint()
        self.test_health_check()
        # Core functionality tests
        conversation_id = self.test_medical_consultation()
        self.test_medical_consultation_french()
        self.test_audio_consultation()
        # Supporting features
        self.test_feedback_submission(conversation_id)
        self.test_medical_specialties()
        # Error handling tests
        self.test_invalid_endpoints()
        self.test_validation_errors()
        # Documentation tests
        self.test_openapi_docs()
        # Print summary
        self.print_summary()

    def print_summary(self):
        """Print pass/fail counts, the failed tests and suggested next steps.

        With no recorded results only a short notice is printed, instead of
        computing a success rate over zero tests.
        """
        print("\n" + "=" * 50)
        print("📊 TEST SUMMARY")
        print("=" * 50)
        total = len(self.results)
        if total == 0:
            # Avoids a division by zero and a misleading "all tests passed".
            print("No tests were run.")
            return
        passed = sum(1 for r in self.results if r['success'])
        failed = total - passed
        print(f"✅ Passed: {passed}/{total}")
        print(f"❌ Failed: {failed}/{total}")
        print(f"📈 Success Rate: {(passed/total)*100:.1f}%")
        if failed > 0:
            print("\n🔍 Failed Tests:")
            for result in self.results:
                if not result['success']:
                    # ``details`` may be empty or None; fall back to an empty mapping.
                    error = (result['details'] or {}).get('error', 'Unknown error')
                    print(f" ❌ {result['test']}: {error}")
        print("\n💡 Next Steps:")
        if passed == total:
            print(" 🎉 All tests passed! Your FastAPI is ready for production.")
            print(" 📚 View API docs at: /docs")
            print(" 🔄 Alternative docs at: /redoc")
        else:
            print(" 🔧 Fix failing tests before deployment")
            print(" 📋 Check logs for detailed error information")
            print(" ⚡ Ensure all ML models are loaded properly")
def main():
    """Prompt for the API base URL and run the complete test suite.

    An empty answer selects ``API_BASE_URL``. The same fallback is used when
    standard input is unavailable (for example a CI job or a closed pipe),
    where ``input()`` raises ``EOFError``.
    """
    print("🩺 FastAPI Medical AI - Test Suite")
    print("🔗 Testing API endpoints for backend integration")
    print()
    # You can change the base URL here for testing deployed versions
    try:
        base_url = input(f"Enter API base URL (default: {API_BASE_URL}): ").strip()
    except EOFError:
        # No interactive stdin: behave as if the user just pressed Enter.
        print()
        base_url = ""
    if not base_url:
        base_url = API_BASE_URL
    tester = MedicalAITester(base_url)
    tester.run_all_tests()
def quick_test():
    """Run only the essential checks against the default API URL.

    Covers the root endpoint, the health check and one text consultation,
    then prints how many of them passed.
    """
    print("⚡ Quick FastAPI Test")
    print("=" * 30)
    tester = MedicalAITester()
    # Test only core functionality
    tester.test_root_endpoint()
    tester.test_health_check()
    tester.test_medical_consultation()
    # Print quick summary
    passed = sum(1 for r in tester.results if r['success'])
    total = len(tester.results)
    print(f"\n⚡ Quick Test: {passed}/{total} passed")
if __name__ == "__main__":
    import sys

    # "quick" as the first command-line argument selects the reduced run;
    # anything else (or no argument) starts the interactive full suite.
    wants_quick = sys.argv[1:2] == ["quick"]
    entry_point = quick_test if wants_quick else main
    entry_point()