Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Test script to validate summary length consistency across multiple requests. | |
| This script tests the SummarizerAgent with various input texts to ensure | |
| that summary lengths don't degrade over multiple requests. | |
| """ | |
| import sys | |
| import os | |
| import logging | |
| from unittest.mock import Mock | |
| # Add the project root to the Python path | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from ai_med_extract.agents.summarizer import SummarizerAgent | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
def create_mock_model_loader():
    """Build a mock model loader whose model fabricates length-controlled summaries.

    The returned loader is a ``Mock`` whose ``load()`` always hands back the
    same callable mock model. Calling that model with a text (plus optional
    ``max_length`` / ``min_length`` keyword arguments) returns a one-element
    list containing a ``{'summary_text': ...}`` dict.

    Returns:
        Mock: loader stub exposing ``load()``.
    """
    mock_loader = Mock()
    mock_model = Mock()

    def mock_generate(text, **kwargs):
        """Return a fake summary whose word count depends on the input size."""
        # None or empty input counts as zero words instead of raising
        # AttributeError on ``None.split()``.
        word_count = len(text.split()) if text else 0
        upper = kwargs.get('max_length', 1024)
        lower = kwargs.get('min_length', 30)
        # Aim for half the input length, clamped to [lower, upper]; the upper
        # bound wins if the two bounds conflict.
        summary_length = min(upper, max(lower, word_count // 2))
        return [{'summary_text': ' '.join(["summary"] * summary_length)}]

    mock_model.side_effect = mock_generate
    mock_loader.load.return_value = mock_model
    return mock_loader
def test_summary_consistency():
    """Check that repeated summaries of the same text keep a stable length.

    Each sample text is summarized several times through a ``SummarizerAgent``
    backed by the mock model loader. A text passes when every summary length
    lies within the allowed percentage of the mean length for that text and
    the summaries are not all empty. Progress and a final report are printed.

    Returns:
        bool: True if every sample text passed, False otherwise.
    """
    requests_per_text = 5
    max_variation_percent = 10  # allowed deviation from the mean, in percent

    print("Testing Summary Consistency Across Multiple Requests")
    print("=" * 60)

    # Create mock model loader
    mock_loader = create_mock_model_loader()
    summarizer = SummarizerAgent(mock_loader)

    # Test with different input texts
    test_texts = [
        "Patient presents with chest pain and shortness of breath. " * 10,
        "Medical history includes hypertension, diabetes, and hyperlipidemia. " * 15,
        "Laboratory results show elevated cholesterol levels and normal blood glucose. " * 20,
        "Physical examination reveals normal heart sounds and clear lung fields. " * 25,
    ]

    results = []
    for i, text in enumerate(test_texts, 1):
        input_words = len(text.split())
        print(f"\nTest {i}: Input text length = {input_words} words")

        # Generate multiple summaries with the same text
        summary_lengths = []
        for request_num in range(1, requests_per_text + 1):
            summary = summarizer.generate_summary(text)
            word_count = len(summary.split())
            summary_lengths.append(word_count)
            print(f" Request {request_num}: {word_count} words")

        # Largest deviation from the mean, expressed as a percentage of it.
        avg_length = sum(summary_lengths) / len(summary_lengths)
        max_variation = max(abs(length - avg_length) for length in summary_lengths)
        variation_percent = (max_variation / avg_length) * 100 if avg_length > 0 else 0

        # All-empty summaries show zero variation but are a failure, not a
        # pass, so a non-zero average length is required as well.
        consistent = avg_length > 0 and variation_percent <= max_variation_percent
        status = "PASS" if consistent else "FAIL"

        results.append({
            'test': i,
            'input_words': input_words,
            'summary_lengths': summary_lengths,
            'avg_length': avg_length,
            'max_variation': max_variation,
            'variation_percent': variation_percent,
            'consistent': consistent,
            'status': status,
        })
        print(f" Consistency: {status} (Variation: {variation_percent:.1f}%)")

    # Print summary of results
    print("\n" + "=" * 60)
    print("SUMMARY OF RESULTS")
    print("=" * 60)

    all_passed = all(result['consistent'] for result in results)
    for result in results:
        print(f"Test {result['test']}: {result['status']}")
        print(f" Input: {result['input_words']} words")
        print(f" Summaries: {result['summary_lengths']}")
        print(f" Avg: {result['avg_length']:.1f}, Max variation: {result['max_variation']:.1f}")
        print(f" Variation: {result['variation_percent']:.1f}%")
        print()

    print(f"OVERALL: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
    return all_passed
def test_edge_cases():
    """Run the summarizer on degenerate inputs and print what comes back.

    Covers a very short text, an empty string and ``None``. Nothing is
    asserted; the results are only printed for inspection.
    """
    print("\nTesting Edge Cases")
    print("=" * 40)

    agent = SummarizerAgent(create_mock_model_loader())

    short_text = "Patient has fever."
    # (label printed before the result, value handed to the summarizer)
    cases = (
        (f"Short text ('{short_text}')", short_text),
        ("Empty text", ""),
        ("None input", None),
    )
    for label, candidate in cases:
        print(f"{label}: '{agent.generate_summary(candidate)}'")
if __name__ == "__main__":
    import traceback

    # Assume failure; only a clean, passing run flips this to 0.
    exit_code = 1
    try:
        # Consistency checks decide the exit status; edge cases only print.
        passed = test_summary_consistency()
        test_edge_cases()
        if passed:
            exit_code = 0
    except Exception as exc:
        print(f"Error during testing: {exc}")
        traceback.print_exc()
    sys.exit(exit_code)