|
|
|
""" |
|
FastAPI Medical AI - Testing Script |
|
Test all endpoints for backend integration |
|
""" |
|
|
|
import requests |
|
import json |
|
import time |
|
import os |
|
from typing import Dict, Any |
|
|
|
|
|
# Server address used when no other base URL is supplied to the tester.
API_BASE_URL: str = "http://localhost:8000"
# Audio fixture uploaded by the audio-consultation check (looked up relative
# to the current working directory).
TEST_AUDIO_PATH: str = "test_audio.wav"
|
|
|
class MedicalAITester:
    """Smoke-test client for the FastAPI Medical AI backend.

    Each ``test_*`` method sends one or more HTTP requests and records a
    PASS/FAIL entry through :meth:`log_result`. Exceptions raised while a
    check runs (connection errors, timeouts, non-JSON bodies, ...) are caught
    and recorded as failures so the remaining checks still execute.
    """

    def __init__(self, base_url: str = API_BASE_URL, timeout: float = 120.0):
        """Create a tester bound to ``base_url``.

        Args:
            base_url: Root URL of the API; a trailing slash is stripped.
            timeout: Seconds to wait for each HTTP request. Without a timeout
                ``requests`` can block forever on an unresponsive server. The
                default is generous because the consultation endpoints may run
                model inference — TODO tune for the deployed backend.
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.session = requests.Session()
        # One dict per check: {'test': str, 'success': bool, 'details': dict}
        self.results = []

    def _request(self, method: str, path: str, **kwargs):
        """Send ``method`` to ``base_url + path`` on the shared session.

        ``self.timeout`` is applied unless the caller passes its own
        ``timeout`` keyword.
        """
        kwargs.setdefault("timeout", self.timeout)
        return self.session.request(method, f"{self.base_url}{path}", **kwargs)

    @staticmethod
    def _error_details(response) -> Dict[str, Any]:
        """Describe an unexpected HTTP response for the results log."""
        return {"error": response.text, "status_code": response.status_code}

    @staticmethod
    def _failure_reason(details) -> str:
        """Pick the most useful one-line explanation from a failure's details."""
        if isinstance(details, dict):
            if details.get('error'):
                return str(details['error'])
            if 'status_code' in details:
                return f"unexpected status code {details['status_code']}"
        return 'Unknown error'

    def log_result(self, test_name: str, success: bool, details: Dict[str, Any]):
        """Print a PASS/FAIL line (plus any details) and append it to ``self.results``."""
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"{status} {test_name}")
        if details:
            print(f" Details: {details}")

        self.results.append({
            'test': test_name,
            'success': success,
            'details': details
        })

    def test_root_endpoint(self):
        """GET / and expect HTTP 200 with a JSON body (logged as the details)."""
        try:
            response = self._request("GET", "/")
            success = response.status_code == 200
            details = response.json() if success else self._error_details(response)
            self.log_result("Root Endpoint", success, details)
        except Exception as e:
            self.log_result("Root Endpoint", False, {"error": str(e)})

    def test_health_check(self):
        """GET /health and report its status, models_loaded and audio_available fields."""
        try:
            response = self._request("GET", "/health")
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "status": data.get("status"),
                    "models_loaded": data.get("models_loaded"),
                    "audio_available": data.get("audio_available"),
                }
            else:
                details = self._error_details(response)
            self.log_result("Health Check", success, details)
        except Exception as e:
            self.log_result("Health Check", False, {"error": str(e)})

    def test_medical_consultation(self):
        """POST an English question to /medical/ask and expect HTTP 200.

        Returns:
            The ``conversation_id`` from a successful response (which may
            itself be None if the server omits it), or None on any failure.
            The caller feeds it into :meth:`test_feedback_submission`.
        """
        payload = {
            "question": "What are the symptoms of malaria and how is it treated?",
            "language": "en",
        }
        try:
            # ``json=`` already sets Content-Type: application/json.
            response = self._request("POST", "/medical/ask", json=payload)
            if response.status_code != 200:
                self.log_result("Medical Consultation (Text)", False, self._error_details(response))
                return None

            data = response.json()
            details = {
                # ``or`` guards against keys present with a JSON null value.
                "response_length": len(data.get("response") or ""),
                "detected_language": data.get("detected_language"),
                "processing_time": data.get("processing_time"),
                "conversation_id": data.get("conversation_id"),
                "contexts_used": len(data.get("context_used") or []),
            }
            self.log_result("Medical Consultation (Text)", True, details)
            return data.get("conversation_id")
        except Exception as e:
            self.log_result("Medical Consultation (Text)", False, {"error": str(e)})
            return None

    def test_medical_consultation_french(self):
        """POST a French question to /medical/ask and expect HTTP 200."""
        payload = {
            "question": "Quels sont les symptômes du diabète?",
            "language": "fr",
        }
        try:
            response = self._request("POST", "/medical/ask", json=payload)
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "response_length": len(data.get("response") or ""),
                    "detected_language": data.get("detected_language"),
                    "processing_time": data.get("processing_time"),
                }
            else:
                details = self._error_details(response)
            self.log_result("Medical Consultation (French)", success, details)
        except Exception as e:
            self.log_result("Medical Consultation (French)", False, {"error": str(e)})

    def test_audio_consultation(self):
        """Upload TEST_AUDIO_PATH to /medical/audio as multipart field 'file'.

        A missing audio fixture is recorded as a failure, not a skip.
        """
        if not os.path.exists(TEST_AUDIO_PATH):
            self.log_result("Audio Consultation", False, {"error": "No test audio file found"})
            return

        try:
            with open(TEST_AUDIO_PATH, 'rb') as audio_file:
                files = {'file': ('test_audio.wav', audio_file, 'audio/wav')}
                response = self._request("POST", "/medical/audio", files=files)

            success = response.status_code == 200
            if success:
                data = response.json()
                transcription = data.get("transcription") or ""
                # Only mark with an ellipsis when the preview is actually truncated.
                preview = transcription[:50] + ("..." if len(transcription) > 50 else "")
                details = {
                    "transcription": preview,
                    "response_length": len(data.get("response") or ""),
                    "detected_language": data.get("detected_language"),
                    "processing_time": data.get("processing_time"),
                    "audio_duration": data.get("audio_duration"),
                }
            else:
                details = self._error_details(response)
            self.log_result("Audio Consultation", success, details)
        except Exception as e:
            self.log_result("Audio Consultation", False, {"error": str(e)})

    def test_feedback_submission(self, conversation_id: str = None):
        """POST a 5-star rating with a text comment to /feedback.

        Args:
            conversation_id: ID returned by an earlier consultation. When it is
                falsy (e.g. the consultation failed), the placeholder
                ``"test_conv_123"`` is sent instead; whether the server
                accepts an unknown ID is not verified here.
        """
        conversation_id = conversation_id or "test_conv_123"
        payload = {
            "conversation_id": conversation_id,
            "rating": 5,
            "feedback": "Great medical advice, very helpful and accurate",
        }
        try:
            response = self._request("POST", "/feedback", json=payload)
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "message": data.get("message"),
                    "feedback_id": data.get("feedback_id"),
                }
            else:
                details = self._error_details(response)
            self.log_result("Feedback Submission", success, details)
        except Exception as e:
            self.log_result("Feedback Submission", False, {"error": str(e)})

    def test_medical_specialties(self):
        """GET /medical/specialties and report the specialty count and languages."""
        try:
            response = self._request("GET", "/medical/specialties")
            success = response.status_code == 200
            if success:
                data = response.json()
                details = {
                    "specialties_count": len(data.get("specialties") or []),
                    "languages_supported": data.get("languages_supported", []),
                }
            else:
                details = self._error_details(response)
            self.log_result("Medical Specialties", success, details)
        except Exception as e:
            self.log_result("Medical Specialties", False, {"error": str(e)})

    def test_invalid_endpoints(self):
        """GET an unknown path and expect HTTP 404."""
        try:
            response = self._request("GET", "/invalid/endpoint")
            success = response.status_code == 404
            self.log_result("404 Error Handling", success, {"status_code": response.status_code})
        except Exception as e:
            self.log_result("404 Error Handling", False, {"error": str(e)})

    def test_validation_errors(self):
        """POST an empty question with an unknown language and expect HTTP 422."""
        payload = {
            "question": "",
            "language": "invalid_lang",
        }
        try:
            response = self._request("POST", "/medical/ask", json=payload)
            success = response.status_code == 422
            self.log_result("Validation Error Handling", success, {"status_code": response.status_code})
        except Exception as e:
            self.log_result("Validation Error Handling", False, {"error": str(e)})

    def test_openapi_docs(self):
        """Expect HTTP 200 from each documentation endpoint (one result per endpoint)."""
        endpoints_to_test = [
            ("/docs", "Swagger UI"),
            ("/redoc", "ReDoc UI"),
            ("/openapi.json", "OpenAPI Schema"),
        ]
        for endpoint, name in endpoints_to_test:
            try:
                response = self._request("GET", endpoint)
                success = response.status_code == 200
                self.log_result(f"Documentation - {name}", success, {"status_code": response.status_code})
            except Exception as e:
                self.log_result(f"Documentation - {name}", False, {"error": str(e)})

    def run_all_tests(self):
        """Run every check in order, then print the summary."""
        print("🧪 Starting FastAPI Medical AI Tests")
        print("=" * 50)

        # Basic availability.
        self.test_root_endpoint()
        self.test_health_check()

        # Consultation endpoints; the text consultation yields an ID for feedback.
        conversation_id = self.test_medical_consultation()
        self.test_medical_consultation_french()
        self.test_audio_consultation()

        # Auxiliary endpoints.
        self.test_feedback_submission(conversation_id)
        self.test_medical_specialties()

        # Error handling.
        self.test_invalid_endpoints()
        self.test_validation_errors()

        # Generated documentation.
        self.test_openapi_docs()

        self.print_summary()

    def print_summary(self):
        """Print pass/fail totals, the failing checks, and suggested next steps."""
        print("\n" + "=" * 50)
        print("📊 TEST SUMMARY")
        print("=" * 50)

        total = len(self.results)
        passed = sum(1 for r in self.results if r['success'])
        failed = total - passed
        # Guard against ZeroDivisionError when no checks have been recorded.
        rate = (passed / total) * 100 if total else 0.0

        print(f"✅ Passed: {passed}/{total}")
        print(f"❌ Failed: {failed}/{total}")
        print(f"📈 Success Rate: {rate:.1f}%")

        if failed:
            print("\n🔍 Failed Tests:")
            for result in self.results:
                if not result['success']:
                    print(f" ❌ {result['test']}: {self._failure_reason(result['details'])}")

        print("\n💡 Next Steps:")
        if total == 0:
            print(" ⚠️ No tests were run.")
        elif failed == 0:
            print(" 🎉 All tests passed! Your FastAPI is ready for production.")
            print(" 📖 View API docs at: /docs")
            print(" 📚 Alternative docs at: /redoc")
        else:
            print(" 🔧 Fix failing tests before deployment")
            print(" 📋 Check logs for detailed error information")
            print(" ⚡ Ensure all ML models are loaded properly")
|
|
def main():
    """Ask for the API base URL, then run the full test suite against it.

    An empty answer uses ``API_BASE_URL``. So does a non-interactive stdin
    (EOF, e.g. when run under CI or with input piped from /dev/null), where
    ``input()`` would otherwise raise EOFError and abort the run.
    """
    print("🩺 FastAPI Medical AI - Test Suite")
    print("🔗 Testing API endpoints for backend integration")
    print()

    try:
        base_url = input(f"Enter API base URL (default: {API_BASE_URL}): ").strip()
    except EOFError:
        # No interactive input available; finish the prompt line and use the default.
        print()
        base_url = ""

    if not base_url:
        base_url = API_BASE_URL

    tester = MedicalAITester(base_url)
    tester.run_all_tests()
|
|
|
def quick_test():
    """Run only the essential checks against ``API_BASE_URL``.

    Covers the root, health, and English text-consultation endpoints, then
    prints a one-line passed/total tally (no detailed summary).
    """
    print("⚡ Quick FastAPI Test")
    print("=" * 30)

    tester = MedicalAITester()

    tester.test_root_endpoint()
    tester.test_health_check()
    tester.test_medical_consultation()

    passed = sum(1 for r in tester.results if r['success'])
    total = len(tester.results)
    print(f"\n⚡ Quick Test: {passed}/{total} passed")
|
|
|
if __name__ == "__main__": |
|
import sys |
|
|
|
if len(sys.argv) > 1 and sys.argv[1] == "quick": |
|
quick_test() |
|
else: |
|
main() |