# Spaces: Runtime error  (status text captured along with this file; not part of the program)
import glob
import json
import logging
import os
import re
import socket
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import backoff
import httpx
import requests
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_exponential
# Model settings applied to every benchmark request.
MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
temperature = 0.2

# Each run logs to its own timestamped file. The format is the bare message,
# and the records this script writes are json.dumps() strings, so the file
# ends up with one JSON object per line.
_run_started = datetime.now()
log_filename = "api_usage_" + _run_started.strftime("%Y%m%d_%H%M%S") + ".json"
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    filename=log_filename,
)
def verify_dns(host: str = "openrouter.ai") -> bool:
    """Check that ``host`` resolves, falling back to Google DNS if it does not.

    If the first lookup fails, ``/etc/resolv.conf`` is overwritten so that
    8.8.8.8 is the only nameserver, and the lookup is attempted once more.

    Args:
        host: Hostname to resolve. Defaults to the OpenRouter API host.

    Returns:
        bool: True if ``host`` resolves, either immediately or after the DNS
        fallback. False if the fallback cannot be written or resolution still
        fails afterwards.
    """
    try:
        socket.gethostbyname(host)
        return True
    except socket.gaierror:
        pass

    print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")
    # NOTE: this replaces every existing nameserver entry and usually
    # requires root privileges; it is a best-effort workaround.
    try:
        with open("/etc/resolv.conf", "w") as f:
            f.write("nameserver 8.8.8.8\n")
    except OSError as e:
        print(f"Failed to update DNS settings: {e}")
        return False

    # Writing the file proves nothing by itself; confirm the lookup now works
    # instead of reporting success unconditionally.
    try:
        socket.gethostbyname(host)
    except socket.gaierror as e:
        print(f"DNS resolution still failing after switching to 8.8.8.8: {e}")
        return False
    return True
def verify_connection() -> bool:
    """Probe the OpenRouter status endpoint.

    Returns:
        bool: True only when the endpoint answers with HTTP 200 (10-second
        request timeout). False on any other status code, or on any exception,
        which is printed.
    """
    # NOTE(review): confirm that OpenRouter actually serves /api/v1/status;
    # a 404 here makes initialize_client() refuse to start.
    status_url = "https://openrouter.ai/api/v1/status"
    try:
        reply = requests.get(status_url, timeout=10)
    except Exception as exc:
        print(f"Connection test failed: {exc}")
        return False
    return reply.status_code == 200
def initialize_client() -> OpenAI:
    """Build an OpenAI-compatible client pointed at OpenRouter.

    The API key comes from ``OPENROUTER_API_KEY``. DNS and HTTP reachability
    are checked, in that order, before the client is constructed.

    Returns:
        OpenAI: Client using the OpenRouter base URL. It has a 120-second
        timeout and an httpx transport configured with ``retries=3``.

    Raises:
        ValueError: If OPENROUTER_API_KEY is unset or empty.
        ConnectionError: If DNS verification or the connection probe fails.
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable is not set.")

    # Generous timeout: multimodal requests carry images and can be slow.
    request_timeout = 120

    # Each pre-flight check is paired with the message raised if it fails.
    preflight_checks = (
        (verify_dns, "DNS verification failed. Please check your network settings."),
        (
            verify_connection,
            "Cannot connect to OpenRouter. Please check your internet connection.",
        ),
    )
    for check, failure_message in preflight_checks:
        if not check():
            raise ConnectionError(failure_message)

    transport = httpx.HTTPTransport(retries=3)
    http_client = httpx.Client(timeout=request_timeout, transport=transport)
    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        timeout=request_timeout,
        http_client=http_client,
    )
# System message sent with every benchmark question.
_SYSTEM_PROMPT = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
Rules:
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
2. Do not add periods, explanations, or any other text
3. Do not use markdown or formatting
4. Do not restate the question
5. Do not explain your reasoning
Examples of valid responses:
A
B
C
Examples of invalid responses:
"A."
"Answer: B"
"C) This shows..."
"The answer is D"
"""


def _normalize_figure_refs(raw_figures: Any) -> List[str]:
    """Convert a question's ``figures`` field into ``"Figure N[x]"`` strings.

    ``raw_figures`` may be a JSON-encoded list, a plain string, a list, or a
    scalar. Every entry is coerced to ``str`` (JSON numbers such as ``1``
    would otherwise fail the ``startswith`` check), and empty entries are
    dropped.
    """
    if isinstance(raw_figures, str):
        try:
            parsed = json.loads(raw_figures)
        except json.JSONDecodeError:
            parsed = [raw_figures]
    else:
        parsed = raw_figures
    # json.loads can return a scalar ("3" -> 3, '"2a"' -> "2a") rather than a
    # list; iterating those directly would fail or split a string into chars.
    if not isinstance(parsed, list):
        parsed = [parsed]

    refs: List[str] = []
    for fig in parsed:
        text = str(fig).strip()
        if not text:
            # "" would become "Figure " and its letter would parse as "Figure".
            continue
        refs.append(text if text.startswith("Figure ") else f"Figure {text}")
    return refs


def _select_subfigures(
    case_figure: Dict[str, Any], figure_letter: Optional[str]
) -> List[Dict[str, Any]]:
    """Return the subfigures of ``case_figure`` that match ``figure_letter``.

    With no letter, every subfigure is returned. Otherwise a subfigure matches
    when its ``number`` ends with the letter or its ``label`` equals it
    (case-insensitive).
    """
    subfigures = case_figure.get("subfigures", [])
    if not figure_letter:
        return subfigures
    wanted = figure_letter.lower()
    return [
        subfig
        for subfig in subfigures
        if subfig.get("number", "").lower().endswith(wanted)
        or subfig.get("label", "").lower() == wanted
    ]


def _collect_image_parts(
    required_figures: List[str], case_details: Dict[str, Any]
) -> "Tuple[List[Dict[str, Any]], List[str]]":
    """Build ``image_url`` message parts for every subfigure a question needs.

    Returns:
        A pair ``(parts, urls)``: chat-message image parts and the matching
        URLs, in the same order.
    """
    parts: List[Dict[str, Any]] = []
    urls: List[str] = []
    case_figures = case_details.get("figures", [])
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        # Letters in the last token select a subfigure: "Figure 2b" -> "b".
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
        for case_figure in case_figures:
            if case_figure["number"] != f"Figure {base_figure_num}":
                continue
            for subfig in _select_subfigures(case_figure, figure_letter):
                if "url" in subfig:
                    parts.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
                    urls.append(subfig["url"])
    return parts, urls


def _request_log_input(question_data: Dict[str, Any], image_urls: List[str]) -> Dict[str, Any]:
    """Describe a request's inputs for the JSON usage log.

    Fields are read with ``.get`` so a question that lacks e.g. an
    explanation can still be logged instead of raising KeyError.
    """
    return {
        "question_data": {
            "question": question_data.get("question"),
            "explanation": question_data.get("explanation"),
            "metadata": question_data.get("metadata", {}),
            "figures": question_data.get("figures"),
        },
        "image_urls": image_urls,
    }


def _log_request_error(
    case_id: str,
    question_id: str,
    question_data: Dict[str, Any],
    image_urls: List[str],
    error: BaseException,
) -> None:
    """Write a ``status: "error"`` record for a failed request to the usage log."""
    log_entry = {
        "case_id": case_id,
        "question_id": question_id,
        "timestamp": datetime.now().isoformat(),
        "model": MODEL_NAME,
        "temperature": temperature,
        "status": "error",
        "error": str(error),
        "input": _request_log_input(question_data, image_urls),
    }
    logging.info(json.dumps(log_entry))


def _usage_summary(usage: Any) -> Optional[Dict[str, Any]]:
    """Token counts from ``response.usage``, or None when the provider omitted it."""
    if usage is None:
        return None
    return {
        "prompt_tokens": getattr(usage, "prompt_tokens", None),
        "completion_tokens": getattr(usage, "completion_tokens", None),
        "total_tokens": getattr(usage, "total_tokens", None),
    }


def create_multimodal_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    client: OpenAI,
) -> Optional[Any]:
    """Send one question plus its referenced images to the model and log the result.

    The response text is normalized with ``validate_answer`` and written back
    onto ``response.choices[0].message.content`` before returning.

    Args:
        question_data: Question record. Reads ``question`` (required),
            ``figures``, ``answer``, ``explanation`` and ``metadata``.
        case_details: Case record; its ``figures`` / ``subfigures`` supply image URLs.
        case_id: ID of the medical case.
        question_id: ID of the specific question.
        client: OpenAI-compatible client instance.

    Returns:
        Optional[Any]: The chat-completion response, or None when no images
        were found for the question (the skip is logged).

    Raises:
        TimeoutError / openai.APITimeoutError: If the request times out.
        ConnectionError / openai.APIConnectionError: If the connection fails
            (re-raised after a 30-second pause).
        Exception: Any other error, after it has been logged.
    """
    # With openai>=1.0, network failures surface as these SDK exceptions rather
    # than the builtin ConnectionError/TimeoutError, so both families are handled.
    from openai import APIConnectionError, APITimeoutError

    # NOTE(review): the prompt announces "the following medical case", but no
    # case text from case_details is included — confirm this is intended.
    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    try:
        required_figures = _normalize_figure_refs(question_data["figures"])
    except KeyError as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    image_parts, image_urls = _collect_image_parts(required_figures, case_details)
    if not image_parts:
        print(f"No images found for case {case_id}, question {question_id}")
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "status": "skipped",
            "reason": "no_images",
            "input": _request_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        return None

    content = [{"type": "text", "text": prompt}] + image_parts

    try:
        start_time = time.time()
        response = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=temperature,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": content},
            ],
        )
        duration = time.time() - start_time

        raw_answer = response.choices[0].message.content
        clean_answer = validate_answer(raw_answer)
        if not clean_answer:
            print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
            print(f"Raw response: {raw_answer}")
        # Callers read the normalized answer straight off the response object.
        response.choices[0].message.content = clean_answer

        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "temperature": temperature,
            "duration": round(duration, 2),
            # Some providers omit usage; that must not turn a good answer into an error.
            "usage": _usage_summary(getattr(response, "usage", None)),
            "model_answer": clean_answer,
            "correct_answer": question_data.get("answer"),
            "input": _request_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        return response
    # APITimeoutError subclasses APIConnectionError, so timeouts must be checked first.
    except (TimeoutError, APITimeoutError) as e:
        print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
        _log_request_error(case_id, question_id, question_data, image_urls, e)
        raise
    except (ConnectionError, APIConnectionError) as e:
        print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
        _log_request_error(case_id, question_id, question_data, image_urls, e)
        # Pause before propagating; this function itself does not retry.
        print("Waiting 30 seconds before re-raising...")
        time.sleep(30)
        raise
    except Exception as e:
        _log_request_error(case_id, question_id, question_data, image_urls, e)
        raise
# Tried in order; the first pattern that matches wins. Word boundaries keep
# letters inside words ("OPTIONAL", "ANSWER: BECAUSE") from being read as answers.
_ANSWER_PATTERNS = (
    re.compile(r"\bANSWER\s*:\s*([A-F])\b"),  # "ANSWER: X"
    re.compile(r"\bANSWER\s+IS\s*:?\s*([A-F])\b"),  # "THE ANSWER IS X"
    re.compile(r"\bOPTION\s*([A-F])\b"),  # "OPTION X"
    re.compile(r"\b([A-F])\)"),  # "X)"
    re.compile(r"\b([A-F])\b"),  # standalone letter
)


def extract_answer(response_text: str) -> Optional[str]:
    """Extract a single-letter answer (A-F) from free-form model output.

    The text is upper-cased and stripped of periods, then matched against
    progressively looser patterns: an explicit "ANSWER:" / "ANSWER IS",
    "OPTION X", "X)", and finally any standalone letter.

    Args:
        response_text: Raw text response from the model. None or "" is allowed.

    Returns:
        Optional[str]: The upper-case answer letter if one is found, otherwise None.
    """
    if not response_text:
        return None
    text = response_text.upper().replace(".", "")
    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(text)
        if match:
            return match.group(1)
    return None
def validate_answer(response_text: str) -> Optional[str]:
    """Reduce a model response to a single answer letter A-F.

    A response that is already exactly one letter (ignoring surrounding
    whitespace and case) is accepted as-is. Otherwise the first *standalone*
    letter A-F is used, so that letters inside words ("THE", "ANSWER",
    "NONE") are never mistaken for the answer.

    Args:
        response_text: Raw text response from the model. None or "" is allowed.

    Returns:
        Optional[str]: The upper-case answer letter, or None if there is none.
    """
    if not response_text:
        return None
    cleaned = response_text.strip().upper()
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned
    # \b on both sides: an unanchored [A-F] would pick the "E" in "THE ANSWER IS D".
    match = re.search(r"\b([A-F])\b", cleaned)
    return match.group(1) if match else None
def load_benchmark_questions(
    case_id: str, benchmark_dir: str = "../benchmark/questions"
) -> List[str]:
    """Find all question files for a given case ID.

    Matches ``<benchmark_dir>/<case_id>/<case_id>_*.json``.

    Args:
        case_id: ID of the medical case.
        benchmark_dir: Root directory holding one sub-directory per case.

    Returns:
        List[str]: Paths to the case's question files. They are sorted so the
        processing order is reproducible; the list is empty if none exist.
    """
    # Escape so glob metacharacters ("[", "*", "?") in the ID match literally.
    escaped_id = glob.escape(str(case_id))
    pattern = os.path.join(glob.escape(benchmark_dir), escaped_id, f"{escaped_id}_*.json")
    # glob order depends on the filesystem; sort for deterministic runs.
    return sorted(glob.glob(pattern))
def count_total_questions(benchmark_dir: str = "../benchmark/questions") -> Tuple[int, int]:
    """Count case directories and question files under ``benchmark_dir``.

    Only visible sub-directories count as cases. Only files named
    ``<case_id>_*.json`` count as questions; that is the pattern the question
    loader globs for, so the totals match what the main loop processes.

    Args:
        benchmark_dir: Root directory holding one sub-directory per case.

    Returns:
        Tuple[int, int]: (total_cases, total_questions)

    Raises:
        FileNotFoundError: If ``benchmark_dir`` does not exist.
    """
    with os.scandir(benchmark_dir) as entries:
        # Skip stray files (READMEs etc.) and hidden entries such as ".git".
        case_ids = sorted(
            entry.name
            for entry in entries
            if entry.is_dir() and not entry.name.startswith(".")
        )

    root = glob.escape(benchmark_dir)
    total_questions = 0
    for case_id in case_ids:
        escaped_id = glob.escape(case_id)
        pattern = os.path.join(root, escaped_id, f"{escaped_id}_*.json")
        total_questions += len(glob.glob(pattern))
    return len(case_ids), total_questions
def main():
    """Run the multiple-choice benchmark for MODEL_NAME over all Eurorad cases.

    Reads case metadata from ``../data/eurorad_metadata.json`` and sends one
    request per question file. Each model answer is printed beside the expected
    one. Questions with no images are counted as skipped; any exception from
    ``create_multimodal_request`` stops the run.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open("../data/eurorad_metadata.json", "r", encoding="utf-8") as file:
        data = json.load(file)

    client = initialize_client()
    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")

    for case_id, case_details in data.items():
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue
        cases_processed += 1

        for question_file in question_files:
            with open(question_file, "r", encoding="utf-8") as file:
                question_data = json.load(file)
            # splitext only strips ".json"; split(".")[0] would truncate dotted stems.
            question_id = os.path.splitext(os.path.basename(question_file))[0]
            questions_processed += 1

            response = create_multimodal_request(
                question_data, case_details, case_id, question_id, client
            )
            if response is None:
                skipped_questions += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
                continue

            print(
                f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
            )
            print(f"Case ID: {case_id}")
            print(f"Question ID: {question_id}")
            print(f"Model Answer: {response.choices[0].message.content}")
            print(f"Correct Answer: {question_data['answer']}\n")

    print("\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")


if __name__ == "__main__":
    main()