# MedRAX-main/experiments/benchmark_gpt4o.py
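"""Benchmark script: evaluate chatgpt-4o-latest on a Eurorad-derived multiple-choice set.

Loads case metadata from ../data/eurorad_metadata.json and per-case question files
from ../benchmark/questions/, sends each question together with its referenced
figures to the OpenAI chat completions API, and logs every request, response, and
cost estimate as one JSON object per line in a timestamped log file.

Requires the OPENAI_API_KEY environment variable to be set.
"""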
import json
import openai
import os
import glob
import time
import logging
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt
model_name = "chatgpt-4o-latest"
temperature = 0.2
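# Every request is logged to a timestamped file, one JSON object per line.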
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
def calculate_cost(
prompt_tokens: int, completion_tokens: int, model: str = "chatgpt-4o-latest"
) -> float:
"""Calculate the cost of API usage based on token counts.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
model: Model name to use for pricing, defaults to chatgpt-4o-latest
Returns:
float: Cost in USD
"""
pricing = {"chatgpt-4o-latest": {"prompt": 5.0, "completion": 15.0}}
rates = pricing.get(model, {"prompt": 5.0, "completion": 15.0})
return (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1000000
@retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
def create_multimodal_request(
question_data: dict, case_details: dict, case_id: str, question_id: str, client: openai.OpenAI
) -> openai.types.chat.ChatCompletion:
"""Create and send a multimodal request to the OpenAI API.
Args:
question_data: Dictionary containing question details and figures
case_details: Dictionary containing case information and figures
case_id: Identifier for the medical case
question_id: Identifier for the specific question
client: OpenAI client instance
Returns:
openai.types.chat.ChatCompletion: API response object, or None if request fails
"""
prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""
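    # Multimodal message content: the text prompt first, figure images appended below.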
content = [{"type": "text", "text": prompt}]
# Parse required figures
try:
# Try multiple ways of parsing figures
if isinstance(question_data["figures"], str):
try:
required_figures = json.loads(question_data["figures"])
except json.JSONDecodeError:
required_figures = [question_data["figures"]]
elif isinstance(question_data["figures"], list):
required_figures = question_data["figures"]
else:
required_figures = [str(question_data["figures"])]
except Exception as e:
print(f"Error parsing figures: {e}")
required_figures = []
# Ensure each figure starts with "Figure "
required_figures = [
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
]
subfigures = []
for figure in required_figures:
# Handle both regular figures and those with letter suffixes
base_figure_num = "".join(filter(str.isdigit, figure))
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
# Find matching figures in case details
matching_figures = [
case_figure
for case_figure in case_details.get("figures", [])
if case_figure["number"] == f"Figure {base_figure_num}"
]
if not matching_figures:
print(f"No matching figure found for {figure} in case {case_id}")
continue
for case_figure in matching_figures:
# If a specific letter is specified, filter subfigures
if figure_letter:
matching_subfigures = [
subfig
for subfig in case_figure.get("subfigures", [])
if subfig.get("number", "").lower().endswith(figure_letter.lower())
or subfig.get("label", "").lower() == figure_letter.lower()
]
subfigures.extend(matching_subfigures)
else:
# If no letter specified, add all subfigures
subfigures.extend(case_figure.get("subfigures", []))
# Add images to content
for subfig in subfigures:
if "url" in subfig:
content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
else:
print(f"Subfigure missing URL: {subfig}")
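    # The system prompt constrains the model to reply with a single answer letter.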
messages = [
{
"role": "system",
"content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F).",
},
{"role": "user", "content": content},
]
if len(content) == 1: # Only the text prompt exists
print(f"No images found for case {case_id}, question {question_id}")
log_entry = {
"case_id": case_id,
"question_id": question_id,
"timestamp": datetime.now().isoformat(),
"model": model_name,
"temperature": temperature,
"status": "skipped",
"reason": "no_images",
"cost": 0,
"input": {
"messages": messages,
"question_data": {
"question": question_data["question"],
"explanation": question_data["explanation"],
"metadata": question_data.get("metadata", {}),
"figures": question_data["figures"],
},
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
},
}
logging.info(json.dumps(log_entry))
return None
try:
start_time = time.time()
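        # max_tokens is kept small because only a single answer letter is expected.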
response = client.chat.completions.create(
model=model_name, messages=messages, max_tokens=50, temperature=temperature
)
duration = time.time() - start_time
log_entry = {
"case_id": case_id,
"question_id": question_id,
"timestamp": datetime.now().isoformat(),
"model": model_name,
"temperature": temperature,
"duration": round(duration, 2),
"usage": {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
},
"cost": calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens),
"model_answer": response.choices[0].message.content,
"correct_answer": question_data["answer"],
"input": {
"messages": messages,
"question_data": {
"question": question_data["question"],
"explanation": question_data["explanation"],
"metadata": question_data.get("metadata", {}),
"figures": question_data["figures"],
},
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
},
}
logging.info(json.dumps(log_entry))
return response
except openai.RateLimitError:
log_entry = {
"case_id": case_id,
"question_id": question_id,
"timestamp": datetime.now().isoformat(),
"model": model_name,
"temperature": temperature,
"status": "error",
"reason": "rate_limit",
"cost": 0,
"input": {
"messages": messages,
"question_data": {
"question": question_data["question"],
"explanation": question_data["explanation"],
"metadata": question_data.get("metadata", {}),
"figures": question_data["figures"],
},
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
},
}
logging.info(json.dumps(log_entry))
print(
f"\nRate limit hit for case {case_id}, question {question_id}. Waiting 20s...",
flush=True,
)
time.sleep(20)
raise
except Exception as e:
log_entry = {
"case_id": case_id,
"question_id": question_id,
"timestamp": datetime.now().isoformat(),
"model": model_name,
"temperature": temperature,
"status": "error",
"error": str(e),
"cost": 0,
"input": {
"messages": messages,
"question_data": {
"question": question_data["question"],
"explanation": question_data["explanation"],
"metadata": question_data.get("metadata", {}),
"figures": question_data["figures"],
},
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
},
}
logging.info(json.dumps(log_entry))
print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
raise
def load_benchmark_questions(case_id: str) -> list:
"""Load benchmark questions for a given case.
Args:
case_id: Identifier for the medical case
Returns:
list: List of paths to question files
"""
benchmark_dir = "../benchmark/questions"
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
def count_total_questions() -> tuple[int, int]:
"""Count total number of cases and questions in benchmark.
Returns:
tuple: (total_cases, total_questions)
"""
total_cases = len(glob.glob("../benchmark/questions/*"))
total_questions = sum(
len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
for case_id in os.listdir("../benchmark/questions")
)
return total_cases, total_questions
def main() -> None:
"""Main function to run the benchmark evaluation."""
with open("../data/eurorad_metadata.json", "r") as file:
data = json.load(file)
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable is not set.")
    client = openai.OpenAI(api_key=api_key)
total_cases, total_questions = count_total_questions()
cases_processed = 0
questions_processed = 0
skipped_questions = 0
print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}")
for case_id, case_details in data.items():
question_files = load_benchmark_questions(case_id)
if not question_files:
continue
cases_processed += 1
for question_file in question_files:
with open(question_file, "r") as file:
question_data = json.load(file)
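            # Question ID is the file name without its .json extension.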
question_id = os.path.basename(question_file).split(".")[0]
questions_processed += 1
response = create_multimodal_request(
question_data, case_details, case_id, question_id, client
)
# Handle cases where response is None
if response is None:
skipped_questions += 1
print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
continue
print(
f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
)
print(f"Case ID: {case_id}")
print(f"Question ID: {question_id}")
print(f"Model Answer: {response.choices[0].message.content}")
print(f"Correct Answer: {question_data['answer']}\n")
print(f"\nBenchmark Summary:")
print(f"Total Cases Processed: {cases_processed}")
print(f"Total Questions Processed: {questions_processed}")
print(f"Total Questions Skipped: {skipped_questions}")
if __name__ == "__main__":
main()