HNTAI / test_summary_consistency.py
sachinchandrankallar's picture
summary consistency
ef42d64
raw
history blame
5.3 kB
#!/usr/bin/env python3
"""
Test script to validate summary length consistency across multiple requests.
This script tests the SummarizerAgent with various input texts to ensure
that summary lengths don't degrade over multiple requests.
"""
import sys
import os
import logging
from unittest.mock import Mock
# Add the project root to the Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from ai_med_extract.agents.summarizer import SummarizerAgent
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def create_mock_model_loader():
    """Build a stand-in model loader whose model emits deterministic summaries.

    The returned object is a ``Mock`` whose ``load(...)`` call always hands
    back the same callable mock model. Calling that model with a text (plus
    optional ``max_length`` / ``min_length`` keyword arguments) returns a
    one-element list shaped like ``[{'summary_text': ...}]`` where the summary
    is the word "summary" repeated a computed number of times.
    """

    def fake_summarize(text, **kwargs):
        # Aim for half the input word count, raised to at least min_length
        # (default 30) and then capped at max_length (default 1024).
        upper_bound = kwargs.get('max_length', 1024)
        lower_bound = kwargs.get('min_length', 30)
        half_of_input = len(text.split()) // 2
        target_words = min(upper_bound, max(lower_bound, half_of_input))
        return [{'summary_text': ' '.join(["summary"] * target_words)}]

    # Calling the mock model delegates to fake_summarize via side_effect.
    model = Mock(side_effect=fake_summarize)
    loader = Mock()
    loader.load.return_value = model
    return loader
def test_summary_consistency():
    """Check that repeated requests for the same text give similar summary lengths.

    For each of four repeated-sentence inputs, ``generate_summary`` is called
    five times on a ``SummarizerAgent`` built around the mock loader. The word
    count of every summary is recorded, and an input passes when no count
    deviates from the mean of its five counts by more than 10%. Progress and
    a final report are printed to stdout.

    Returns:
        True if every input text passed the check, False otherwise.
    """
    requests_per_text = 5
    tolerance_percent = 10
    divider = "=" * 60

    print("Testing Summary Consistency Across Multiple Requests")
    print(divider)

    agent = SummarizerAgent(create_mock_model_loader())

    # Each input is one sentence repeated a growing number of times.
    sentence_repeats = (
        ("Patient presents with chest pain and shortness of breath. ", 10),
        ("Medical history includes hypertension, diabetes, and hyperlipidemia. ", 15),
        ("Laboratory results show elevated cholesterol levels and normal blood glucose. ", 20),
        ("Physical examination reveals normal heart sounds and clear lung fields. ", 25),
    )
    sample_texts = [sentence * repeats for sentence, repeats in sentence_repeats]

    report = []
    for test_number, text in enumerate(sample_texts, start=1):
        input_words = len(text.split())
        print(f"\nTest {test_number}: Input text length = {input_words} words")

        # Summarize the same text repeatedly, recording each summary's word count.
        lengths = []
        for attempt in range(1, requests_per_text + 1):
            produced = len(agent.generate_summary(text).split())
            lengths.append(produced)
            print(f" Request {attempt}: {produced} words")

        # Largest absolute deviation from the mean, expressed as a percentage of the mean.
        mean_length = sum(lengths) / len(lengths)
        worst_deviation = max(map(lambda count: abs(count - mean_length), lengths))
        if mean_length > 0:
            deviation_percent = (worst_deviation / mean_length) * 100
        else:
            # A zero mean would divide by zero; report no variation instead.
            deviation_percent = 0

        within_tolerance = deviation_percent <= tolerance_percent
        if within_tolerance:
            verdict = "PASS"
        else:
            verdict = "FAIL"

        report.append({
            'test': test_number,
            'input_words': input_words,
            'summary_lengths': lengths,
            'avg_length': mean_length,
            'max_variation': worst_deviation,
            'variation_percent': deviation_percent,
            'consistent': within_tolerance,
            'status': verdict,
        })
        print(f" Consistency: {verdict} (Variation: {deviation_percent:.1f}%)")

    # Final per-test report followed by the overall verdict.
    print("\n" + divider)
    print("SUMMARY OF RESULTS")
    print(divider)
    for entry in report:
        print(f"Test {entry['test']}: {entry['status']}")
        print(f" Input: {entry['input_words']} words")
        print(f" Summaries: {entry['summary_lengths']}")
        print(f" Avg: {entry['avg_length']:.1f}, Max variation: {entry['max_variation']:.1f}")
        print(f" Variation: {entry['variation_percent']:.1f}%")
        print()

    overall_ok = all(entry['consistent'] for entry in report)
    print(f"OVERALL: {'ALL TESTS PASSED' if overall_ok else 'SOME TESTS FAILED'}")
    return overall_ok
def test_edge_cases():
    """Run the summarizer on a very short, an empty, and a ``None`` input.

    Each result is printed to stdout. Nothing is asserted or returned, so the
    only failure mode is an exception raised by ``generate_summary``.
    """
    print("\nTesting Edge Cases")
    print("=" * 40)

    agent = SummarizerAgent(create_mock_model_loader())

    short_text = "Patient has fever."
    # (label printed before the result, input handed to the summarizer)
    cases = (
        (f"Short text ('{short_text}')", short_text),
        ("Empty text", ""),
        ("None input", None),
    )
    for label, payload in cases:
        outcome = agent.generate_summary(payload)
        print(f"{label}: '{outcome}'")
if __name__ == "__main__":
    # Script entry point: run both test groups and turn the outcome into an
    # exit status (0 when the consistency check passed, 1 otherwise).
    try:
        consistency_passed = test_summary_consistency()
        # Edge cases only print their results; they affect the exit status
        # solely by raising.
        test_edge_cases()
    except Exception as e:
        # Any unexpected error is reported with its traceback and counts as a failure.
        print(f"Error during testing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    else:
        if consistency_passed:
            sys.exit(0)
        sys.exit(1)