# MedRAX-main / experiments / benchmark_llama.py
# Uploaded by asbamit ("Upload folder using huggingface_hub"), commit 84f6785 (verified).
import glob
import json
import logging
import os
import re
import socket
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import backoff
import httpx
import requests
from openai import APIConnectionError, OpenAI
from tenacity import retry, stop_after_attempt, wait_exponential
# Model identity and sampling temperature shared by every request in this run.
MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
temperature = 0.2

# Each run logs to a fresh timestamped file. Records are written with a bare
# "%(message)s" format; callers in this module log json.dumps(...) strings, so
# the file holds one JSON object per line despite the ".json" suffix.
log_filename = "api_usage_{}.json".format(datetime.now().strftime("%Y%m%d_%H%M%S"))
logging.basicConfig(
    filename=log_filename,
    format="%(message)s",
    level=logging.INFO,
)
def verify_dns(hostname: str = "openrouter.ai") -> bool:
    """Verify that ``hostname`` resolves, falling back to Google DNS if not.

    If the first lookup fails, ``/etc/resolv.conf`` is overwritten so that
    8.8.8.8 is its only nameserver, and the lookup is attempted again. Note
    that this replaces every existing nameserver entry and normally requires
    root privileges.

    Args:
        hostname: Host name to resolve. Defaults to the OpenRouter API host.

    Returns:
        bool: True if DNS resolution succeeds (directly or after the
        fallback), False otherwise.
    """
    try:
        socket.gethostbyname(hostname)
        return True
    except socket.gaierror:
        print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")

    # Modify resolv.conf to use Google DNS
    try:
        with open("/etc/resolv.conf", "w") as f:
            f.write("nameserver 8.8.8.8\n")
    except OSError as e:
        print(f"Failed to update DNS settings: {e}")
        return False

    # Rewriting the file does not prove resolution works; check again so the
    # return value matches the documented contract.
    try:
        socket.gethostbyname(hostname)
        return True
    except socket.gaierror as e:
        print(f"DNS resolution still failing after switching to 8.8.8.8: {e}")
        return False
def verify_connection() -> bool:
    """Probe the OpenRouter status endpoint.

    Returns:
        bool: True when the GET request answers with HTTP 200; False for any
        other status code or when the request raises (the error is printed).
    """
    status_url = "https://openrouter.ai/api/v1/status"
    try:
        reply = requests.get(status_url, timeout=10)
    except Exception as exc:
        print(f"Connection test failed: {exc}")
        return False
    return reply.status_code == 200
def initialize_client() -> OpenAI:
    """Build an OpenAI-compatible client that talks to OpenRouter.

    The API key comes from the ``OPENROUTER_API_KEY`` environment variable.
    DNS resolution and API reachability are checked (in that order) before
    the client is constructed.

    Returns:
        OpenAI: Client for ``https://openrouter.ai/api/v1`` using a 120-second
        timeout and an httpx transport configured with ``retries=3``.

    Raises:
        ValueError: If OPENROUTER_API_KEY environment variable is not set
        ConnectionError: If DNS verification or connection test fails
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable is not set.")

    # Generous timeout: requests carrying several images can be slow.
    request_timeout = 120

    # Run the network checks in order; stop at the first one that fails.
    preflight_checks = (
        (verify_dns, "DNS verification failed. Please check your network settings."),
        (
            verify_connection,
            "Cannot connect to OpenRouter. Please check your internet connection.",
        ),
    )
    for check, failure_message in preflight_checks:
        if not check():
            raise ConnectionError(failure_message)

    transport = httpx.HTTPTransport(retries=3)
    http_client = httpx.Client(timeout=request_timeout, transport=transport)
    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        timeout=request_timeout,
        http_client=http_client,
    )
def _parse_required_figures(figures: Any) -> List[str]:
    """Normalize a question's ``figures`` field into ``"Figure <id>"`` labels.

    Accepts a JSON-encoded string (e.g. ``'["Figure 1a", "2"]'``), a plain
    string (``"Figure 1"``), a list, or any other scalar. A JSON string that
    decodes to a non-list value (e.g. ``"3"``) is wrapped in a list instead of
    being iterated, and every entry is converted to ``str`` before the
    ``"Figure "`` prefix is added where it is missing.
    """
    try:
        if isinstance(figures, str):
            try:
                parsed = json.loads(figures)
            except json.JSONDecodeError:
                parsed = [figures]
        elif isinstance(figures, list):
            parsed = figures
        else:
            parsed = [str(figures)]
        if not isinstance(parsed, list):
            parsed = [parsed]
        labels = [str(item) for item in parsed]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        labels = []
    return [
        label if label.startswith("Figure ") else f"Figure {label}"
        for label in labels
    ]


def _collect_image_urls(
    required_figures: List[str], case_details: Dict[str, Any]
) -> List[str]:
    """Return the subfigure image URLs referenced by ``required_figures``.

    A label such as ``"Figure 2b"`` selects the case figure numbered
    ``"Figure 2"`` and, within it, the subfigures whose ``number`` ends with
    ``b`` or whose ``label`` equals ``b`` (case-insensitive). A label without
    a letter (``"Figure 2"``) selects all of that figure's subfigures.
    Subfigures without a ``url`` key are ignored.
    """
    image_urls: List[str] = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        # Letters in the last token name the subfigure, e.g. "1a" -> "a".
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]
        for case_figure in matching_figures:
            subfigures = case_figure.get("subfigures", [])
            if figure_letter:
                letter = figure_letter.lower()
                subfigures = [
                    subfig
                    for subfig in subfigures
                    if subfig.get("number", "").lower().endswith(letter)
                    or subfig.get("label", "").lower() == letter
                ]
            image_urls.extend(
                subfig["url"] for subfig in subfigures if "url" in subfig
            )
    return image_urls


def _question_log_input(
    question_data: Dict[str, Any], image_urls: List[str]
) -> Dict[str, Any]:
    """Build the ``input`` section shared by every log entry for a question."""
    return {
        "question_data": {
            "question": question_data["question"],
            "explanation": question_data["explanation"],
            "metadata": question_data.get("metadata", {}),
            "figures": question_data["figures"],
        },
        "image_urls": image_urls,
    }


@backoff.on_exception(
    backoff.expo,
    # openai's APIConnectionError (and its APITimeoutError subclass) does not
    # inherit from the builtin ConnectionError/TimeoutError, so it is listed
    # explicitly; otherwise network failures from the client are never retried.
    (
        ConnectionError,
        TimeoutError,
        socket.gaierror,
        httpx.ConnectError,
        APIConnectionError,
    ),
    max_tries=5,
    max_time=300,  # Maximum total time to try in seconds
)
def create_multimodal_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    client: OpenAI,
) -> Optional[Any]:
    """Create and send a multimodal request to the model.

    The question's ``figures`` field is resolved against ``case_details`` to
    collect subfigure image URLs. If none are found, the question is logged as
    skipped and ``None`` is returned without calling the API. Otherwise the
    question text and images are sent, and the reply is normalized with
    ``validate_answer``. The cleaned letter (or None) replaces the message
    content on the returned response, and a usage record is logged as JSON.

    The decorator retries the whole call with exponential backoff (at most 5
    tries / 300 s) on connection-level failures.

    Args:
        question_data: Dictionary containing question details
        case_details: Dictionary containing case information
        case_id: ID of the medical case
        question_id: ID of the specific question
        client: OpenAI client instance

    Returns:
        Optional[Any]: Model response if successful, None if skipped

    Raises:
        ConnectionError: If connection fails (printed, then re-raised)
        TimeoutError: If request times out (printed, then re-raised)
        Exception: Any other error is logged as an "error" entry and re-raised
    """
    system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
Rules:
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
2. Do not add periods, explanations, or any other text
3. Do not use markdown or formatting
4. Do not restate the question
5. Do not explain your reasoning
Examples of valid responses:
A
B
C
Examples of invalid responses:
"A."
"Answer: B"
"C) This shows..."
"The answer is D"
"""
    # NOTE(review): no case text from case_details is inserted after
    # "Given the following medical case:"; case_details only supplies images.
    # Confirm this is intended before changing, as it affects benchmark scores.
    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    required_figures = _parse_required_figures(question_data["figures"])
    image_urls = _collect_image_urls(required_figures, case_details)
    content = [{"type": "text", "text": prompt}] + [
        {"type": "image_url", "image_url": {"url": url}} for url in image_urls
    ]

    if not image_urls:  # Only the text prompt exists
        print(f"No images found for case {case_id}, question {question_id}")
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "status": "skipped",
            "reason": "no_images",
            "input": _question_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        return None

    try:
        start_time = time.time()
        response = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=temperature,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
        )
        duration = time.time() - start_time

        raw_answer = response.choices[0].message.content
        clean_answer = validate_answer(raw_answer)
        if not clean_answer:
            print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
            print(f"Raw response: {raw_answer}")
        # Callers read the answer from the response object, so store the
        # cleaned letter (or None) there.
        response.choices[0].message.content = clean_answer

        # ``usage`` is Optional on chat completions; record None rather than
        # turning a successful call into an AttributeError.
        usage = response.usage
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "temperature": temperature,
            "duration": round(duration, 2),
            "usage": {
                "prompt_tokens": getattr(usage, "prompt_tokens", None),
                "completion_tokens": getattr(usage, "completion_tokens", None),
                "total_tokens": getattr(usage, "total_tokens", None),
            },
            "model_answer": clean_answer,
            "correct_answer": question_data["answer"],
            "input": _question_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        return response
    except ConnectionError as e:
        print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
        print("Retrying after a longer delay...")
        time.sleep(30)  # Add a longer delay before retry
        raise
    except TimeoutError as e:
        # NOTE: the retry reuses the same client timeout; only backoff delays grow.
        print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
        print("Retrying with increased timeout...")
        raise
    except Exception as e:
        # Log failed requests too (including openai APIConnectionError, which
        # the decorator will then retry).
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "temperature": temperature,
            "status": "error",
            "error": str(e),
            "input": _question_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        raise
def extract_answer(response_text: str) -> Optional[str]:
    """Extract single letter answer from model response.

    Patterns are tried in priority order, and the first match wins:
    an explicit ``ANSWER:`` label, then ``OPTION X``, then a letter followed
    by ``)``, then any standalone letter. Every pattern requires the letter to
    be a whole word. Text such as ``"OPTIONALLY"`` or ``"REF)"`` therefore
    does not produce a spurious ``A`` or ``F``.

    Args:
        response_text: Raw text response from model (None or empty allowed)

    Returns:
        Optional[str]: Upper-case letter A-F if found, None otherwise
    """
    if not response_text:
        return None
    # Convert to uppercase and remove periods so "a." reads as "A".
    text = response_text.upper().replace(".", "")
    patterns = [
        r"\bANSWER:\s*([A-F])\b",  # Matches "ANSWER: X"
        r"\bOPTION\s*([A-F])\b",  # Matches "OPTION X"
        r"\b([A-F])\)",  # Matches "X)"
        r"\b([A-F])\b",  # Matches single letter
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1)
    return None
def validate_answer(response_text: str) -> Optional[str]:
    """Enforce strict single-letter response format.

    A response that is exactly one letter A-F, ignoring surrounding
    whitespace and case, is returned upper-cased. Otherwise the letter is
    taken from the most explicit pattern present, tried in priority order:
    an ``ANSWER``/``OPTION`` label, then a letter followed by ``)``, then any
    standalone letter. Letters are only matched as whole words, so
    ``"The answer is D"`` yields ``"D"``, not the ``E`` inside ``THE``.

    Args:
        response_text: Raw text response from model (None or empty allowed)

    Returns:
        Optional[str]: Valid single letter answer if found, None otherwise
    """
    if not response_text:
        return None
    # Trim surrounding whitespace and convert to uppercase
    cleaned = response_text.strip().upper()
    # Check if it's exactly one valid letter
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned
    # If not, try to extract just the letter, most explicit form first
    for pattern in (
        r"\b(?:ANSWER|OPTION)\W*([A-F])\b",
        r"\b([A-F])\)",
        r"\b([A-F])\b",
    ):
        match = re.search(pattern, cleaned)
        if match:
            return match.group(1)
    return None
def load_benchmark_questions(
    case_id: str, benchmark_dir: str = "../benchmark/questions"
) -> List[str]:
    """Find all question files for a given case ID.

    Question files are matched as ``<benchmark_dir>/<case_id>/<case_id>_*.json``.
    Both path components are glob-escaped, so IDs containing ``*``, ``?`` or
    ``[`` match literally. Results are sorted because ``glob.glob`` returns
    paths in arbitrary order, which would make run order non-deterministic.

    Args:
        case_id: ID of the medical case
        benchmark_dir: Directory holding one sub-directory per case

    Returns:
        List[str]: Sorted list of paths to question files (empty if none)
    """
    pattern = os.path.join(
        glob.escape(benchmark_dir),
        glob.escape(case_id),
        f"{glob.escape(case_id)}_*.json",
    )
    return sorted(glob.glob(pattern))
def count_total_questions(
    benchmark_dir: str = "../benchmark/questions",
) -> Tuple[int, int]:
    """Count total number of cases and questions.

    A case is a non-hidden sub-directory of ``benchmark_dir``. Stray files
    such as a README are not counted. A case's questions are the files named
    ``<case_id>_*.json`` inside it, which is the pattern
    ``load_benchmark_questions`` matches, so the totals agree with what the
    benchmark loop loads.

    Args:
        benchmark_dir: Directory holding one sub-directory per case

    Returns:
        Tuple[int, int]: (total_cases, total_questions)
    """
    case_ids = [
        entry
        for entry in os.listdir(benchmark_dir)
        # Skip dot-entries, as the original glob("*") did.
        if not entry.startswith(".")
        and os.path.isdir(os.path.join(benchmark_dir, entry))
    ]
    escaped_root = glob.escape(benchmark_dir)
    total_questions = sum(
        len(
            glob.glob(
                os.path.join(
                    escaped_root,
                    glob.escape(case_id),
                    f"{glob.escape(case_id)}_*.json",
                )
            )
        )
        for case_id in case_ids
    )
    return len(case_ids), total_questions
def main():
    """Run the benchmark over every case in the metadata file and print a summary.

    For each case in ``../data/eurorad_metadata.json`` that has question
    files, every question is sent to the model via
    ``create_multimodal_request``, and the model's and correct answers are
    printed. Questions with no images are counted as skipped. A question whose
    request still raises after that function's retries is printed and counted
    as failed, instead of aborting the remaining run.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open("../data/eurorad_metadata.json", "r", encoding="utf-8") as file:
        data = json.load(file)
    client = initialize_client()
    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0
    failed_questions = 0
    print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")
    for case_id, case_details in data.items():
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue
        cases_processed += 1
        for question_file in question_files:
            with open(question_file, "r", encoding="utf-8") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]
            questions_processed += 1
            try:
                response = create_multimodal_request(
                    question_data, case_details, case_id, question_id, client
                )
            except Exception as e:
                # Top-level boundary: record the failure and keep the long
                # benchmark run going rather than losing all remaining work.
                failed_questions += 1
                print(f"Failed question: Case ID {case_id}, Question ID {question_id}: {e}")
                continue
            if response is None:
                skipped_questions += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
                continue
            print(
                f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
            )
            print(f"Case ID: {case_id}")
            print(f"Question ID: {question_id}")
            print(f"Model Answer: {response.choices[0].message.content}")
            print(f"Correct Answer: {question_data['answer']}\n")
    print("\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")
    print(f"Total Questions Failed: {failed_questions}")
if __name__ == "__main__":
    # Start the benchmark only when executed as a script, not when imported.
    main()