# MedRAX-main/experiments/analyze_axes.py
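"""Analyze benchmark results along question-category axes.

Reads a JSONL results file produced in one of several model output formats
(GPT-4/MedRAX, Llama, CheXagent), computes overall accuracy, per-category
accuracy, and accuracy per question type, and saves the correct/incorrect
question IDs to a JSON file.
"""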
from typing import Dict, List, Optional, Tuple, Any
import json
import argparse
from collections import defaultdict

from tqdm import tqdm

QUESTION_TYPES = {
"Detailed Finding Analysis": ["detection", "localization", "characterization"],
"Pattern Recognition & Relations": ["detection", "classification", "relationship"],
"Spatial Understanding": ["localization", "comparison", "relationship"],
"Clinical Decision Making": ["classification", "comparison", "diagnosis"],
"Diagnostic Classification": ["classification", "characterization", "diagnosis"],
}
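
# Question-type statistics are computed by summing the per-category counts of
# each type's listed categories, so a question tagged with several categories
# of the same type is counted once per matching category (see process_results).
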
def extract_answer_letter(answer: Optional[Any]) -> Optional[str]:
"""
Extract just the letter from various answer formats.
Args:
answer: The answer text to extract letter from
Returns:
Optional[str]: The extracted letter in uppercase, or None if no letter found
"""
if not answer:
return None
# Convert to string and clean
answer = str(answer).strip()
# If it's just a single letter, return it
if len(answer) == 1 and answer.isalpha():
return answer.upper()
# Try to extract letter from format like "A)" or "A."
if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
return answer[0].upper()
return None
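
# Illustrative behaviour of extract_answer_letter:
#   "C"                 -> "C"
#   "b) Pneumothorax"   -> "B"
#   "Pleural effusion"  -> None  (no leading option letter)
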
def analyze_gpt4_results(
results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
"""
Analyze results in GPT-4 format.
Args:
results_file: Path to results file
max_questions: Maximum number of questions to analyze
Returns:
Tuple containing:
- overall_accuracy (float)
- category_accuracies (Dict)
- question_type_stats (Dict)
- correct_ids (List[str])
- incorrect_ids (List[str])
"""
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
all_questions = 0
all_correct = 0
correct_ids = []
incorrect_ids = []
with open(results_file, "r") as f:
lines = f.readlines()
processed_questions = 0
for line in tqdm(lines, desc="Analyzing Benchmark Results"):
# Check if we've hit the maximum questions
if max_questions is not None and processed_questions >= max_questions:
break
if line.startswith("HTTP Request:"):
continue
try:
entry = json.loads(line)
metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
question_id = entry.get("question_id")
model_letter = extract_answer_letter(entry.get("model_answer"))
correct_letter = extract_answer_letter(entry.get("correct_answer"))
if model_letter and correct_letter:
all_questions += 1
processed_questions += 1
is_correct = model_letter == correct_letter
if is_correct:
all_correct += 1
correct_ids.append(question_id)
else:
incorrect_ids.append(question_id)
for category in metadata.get("categories", []):
category_performance[category]["total"] += 1
if is_correct:
category_performance[category]["correct"] += 1
except json.JSONDecodeError:
continue
return process_results(
category_performance, all_questions, all_correct, correct_ids, incorrect_ids
)
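
# Shape of a result line the analyzers in this file expect (field names are
# taken from the lookups above; the concrete values are only illustrative):
# {"question_id": "q_0001",
#  "model_answer": "B) Right pleural effusion",
#  "correct_answer": "B",
#  "input": {"question_data": {"metadata": {"categories": ["detection", "localization"]}}}}
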
def analyze_llama_results(
results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
"""
Analyze results in Llama format.
Args:
results_file: Path to results file
max_questions: Maximum number of questions to analyze
Returns:
Tuple containing:
- overall_accuracy (float)
- category_accuracies (Dict)
- question_type_stats (Dict)
- correct_ids (List[str])
- incorrect_ids (List[str])
"""
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
all_questions = 0
all_correct = 0
correct_ids = []
incorrect_ids = []
with open(results_file, "r") as f:
lines = f.readlines()
    # If max_questions is set, limit the number of lines read. Note that this
    # caps raw lines (including any non-JSON noise), unlike analyze_gpt4_results,
    # which caps parsed questions.
    if max_questions is not None:
        lines = lines[:max_questions]
for line in tqdm(lines, desc="Analyzing Benchmark Results"):
if line.startswith("HTTP Request:"):
continue
try:
entry = json.loads(line)
metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
question_id = entry.get("question_id")
model_letter = extract_answer_letter(entry.get("model_answer"))
correct_letter = extract_answer_letter(entry.get("correct_answer"))
if model_letter and correct_letter:
all_questions += 1
is_correct = model_letter == correct_letter
if is_correct:
all_correct += 1
correct_ids.append(question_id)
else:
incorrect_ids.append(question_id)
for category in metadata.get("categories", []):
category_performance[category]["total"] += 1
if is_correct:
category_performance[category]["correct"] += 1
except json.JSONDecodeError:
continue
return process_results(
category_performance, all_questions, all_correct, correct_ids, incorrect_ids
)
def analyze_chexagent_results(
results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
"""
Analyze results in CheXagent format.
Args:
results_file: Path to results file
max_questions: Maximum number of questions to analyze
Returns:
Tuple containing:
- overall_accuracy (float)
- category_accuracies (Dict)
- question_type_stats (Dict)
- correct_ids (List[str])
- incorrect_ids (List[str])
"""
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
all_questions = 0
all_correct = 0
correct_ids = []
incorrect_ids = []
with open(results_file, "r") as f:
lines = f.readlines()
    # If max_questions is set, limit the number of lines read. Note that this
    # caps raw lines (including any non-JSON noise), unlike analyze_gpt4_results,
    # which caps parsed questions.
    if max_questions is not None:
        lines = lines[:max_questions]
for line in tqdm(lines, desc="Analyzing Benchmark Results"):
try:
entry = json.loads(line)
metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
question_id = entry.get("question_id")
model_letter = extract_answer_letter(entry.get("model_answer"))
correct_letter = extract_answer_letter(entry.get("correct_answer"))
if model_letter and correct_letter:
all_questions += 1
is_correct = model_letter == correct_letter
if is_correct:
all_correct += 1
correct_ids.append(question_id)
else:
incorrect_ids.append(question_id)
for category in metadata.get("categories", []):
category_performance[category]["total"] += 1
if is_correct:
category_performance[category]["correct"] += 1
except json.JSONDecodeError:
continue
return process_results(
category_performance, all_questions, all_correct, correct_ids, incorrect_ids
)
def process_results(
category_performance: Dict,
all_questions: int,
all_correct: int,
correct_ids: Optional[List[str]] = None,
incorrect_ids: Optional[List[str]] = None,
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
"""
Process raw results into final statistics.
Args:
category_performance: Dict containing performance by category
all_questions: Total number of questions
all_correct: Total number of correct answers
correct_ids: List of IDs for correctly answered questions
incorrect_ids: List of IDs for incorrectly answered questions
Returns:
Tuple containing:
- overall_accuracy (float)
- category_accuracies (Dict)
- question_type_stats (Dict)
- correct_ids (List[str])
- incorrect_ids (List[str])
"""
category_accuracies = {
category: {
"accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
"total": stats["total"],
"correct": stats["correct"],
}
for category, stats in category_performance.items()
}
question_type_stats = {}
for qtype, categories in QUESTION_TYPES.items():
total = sum(
category_performance[cat]["total"] for cat in categories if cat in category_performance
)
correct = sum(
category_performance[cat]["correct"]
for cat in categories
if cat in category_performance
)
question_type_stats[qtype] = {
"accuracy": (correct / total * 100) if total > 0 else 0,
"total": total,
"correct": correct,
}
overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0
return (
overall_accuracy,
category_accuracies,
question_type_stats,
correct_ids or [],
incorrect_ids or [],
)
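
# Example of the aggregated output (values illustrative):
#   category_accuracies["detection"] -> {"accuracy": 75.0, "total": 40, "correct": 30}
#   question_type_stats uses the same keys, aggregated over the categories
#   listed for each entry of QUESTION_TYPES.
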
def print_analysis(
overall_accuracy: float,
category_accuracies: Dict,
question_type_stats: Dict,
correct_ids: List[str],
incorrect_ids: List[str],
model_name: str,
) -> None:
"""
Print analysis results.
Args:
overall_accuracy: Overall accuracy percentage
category_accuracies: Dict containing accuracy metrics by category
question_type_stats: Dict containing stats by question type
correct_ids: List of IDs for correctly answered questions
incorrect_ids: List of IDs for incorrectly answered questions
model_name: Name of the model being analyzed
"""
total_questions = len(correct_ids) + len(incorrect_ids)
print(
f"\nOverall Accuracy: {overall_accuracy:.2f}% ({len(correct_ids)} correct out of {total_questions} questions)"
)
print("\nCategory Performance:")
sorted_categories = sorted(
category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
)
for category, metrics in sorted_categories:
print(f"{category}:")
print(f" Accuracy: {metrics['accuracy']:.2f}%")
print(f" Total Questions: {metrics['total']}")
print(f" Correct Questions: {metrics['correct']}")
print("\nQuestion Type Performance:")
sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
for qtype, metrics in sorted_types:
print(f"\n{qtype}:")
print(f" Accuracy: {metrics['accuracy']:.2f}%")
print(f" Total Questions: {metrics['total']}")
print(f" Correct Questions: {metrics['correct']}")
print(f" Categories: {', '.join(QUESTION_TYPES[qtype])}")
# Save question IDs to JSON
question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}
output_filename = f"{model_name}_question_ids.json"
with open(output_filename, "w") as f:
json.dump(question_ids, f, indent=2)
print(f"\nQuestion IDs have been saved to {output_filename}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Analyze benchmark results")
parser.add_argument("results_file", help="Path to results file")
    parser.add_argument(
        "benchmark_dir", nargs="?", help="Path to benchmark questions directory (currently unused)"
    )
parser.add_argument(
"--model",
choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
default="gpt4",
help="Specify model format (default: gpt4)",
)
parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
args = parser.parse_args()
if args.model == "gpt4":
results = analyze_gpt4_results(args.results_file, args.max_questions)
elif args.model == "llama":
results = analyze_llama_results(args.results_file, args.max_questions)
elif args.model == "chexagent":
results = analyze_chexagent_results(args.results_file, args.max_questions)
    elif args.model == "medrax":
        # MedRAX results are written in the same format as the GPT-4 results.
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    else:
        # "llava-med" is accepted by argparse but has no analyzer in this script.
        parser.error(f"Unsupported model: {args.model}")
print_analysis(*results, args.model)
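
# Example invocation (file paths are illustrative):
#   python analyze_axes.py results/gpt4_results.jsonl --model gpt4 --max-questions 500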