File size: 14,494 Bytes
84f6785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
import glob
import json
import logging
import os
import re
import socket
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import backoff
import httpx
import requests
from openai import OpenAI
from tenacity import retry, wait_exponential, stop_after_attempt

# Model settings: MODEL_NAME is sent as the `model` argument of every chat
# completion request and `temperature` as its sampling temperature; both are
# also recorded in each log entry.
MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
temperature = 0.2

# Logging: a new log file is created per run, named after the time this module
# is imported. The format is the bare message, so every record is written
# exactly as passed to logging.info() (this script passes json.dumps() strings,
# i.e. one JSON object per line rather than a single JSON document).
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")


def verify_dns() -> bool:
    """Verify that ``openrouter.ai`` resolves, falling back to Google DNS.

    The host is resolved with the system resolver first. If that fails, the
    function rewrites ``/etc/resolv.conf`` to point at ``8.8.8.8`` and then
    resolves the host again, so the return value reflects whether resolution
    actually works rather than whether the file could be written.

    Returns:
        bool: True if ``openrouter.ai`` resolves (directly or after the
        fallback), False otherwise.
    """
    host = "openrouter.ai"

    try:
        socket.gethostbyname(host)
        return True
    except socket.gaierror:
        print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")

    # WARNING: this truncates the system-wide resolver configuration and
    # replaces it with a single nameserver entry. It normally requires root
    # and discards any existing nameserver/search settings.
    try:
        with open("/etc/resolv.conf", "w") as f:
            f.write("nameserver 8.8.8.8\n")
    except OSError as e:
        print(f"Failed to update DNS settings: {e}")
        return False

    # Re-check instead of assuming the new nameserver fixed the problem.
    # NOTE(review): whether an already-running process picks up the rewritten
    # resolv.conf depends on the platform's resolver -- confirm on the target OS.
    try:
        socket.gethostbyname(host)
    except socket.gaierror as e:
        print(f"DNS resolution still failing after switching to Google DNS: {e}")
        return False
    return True


def verify_connection() -> bool:
    """Probe the OpenRouter status endpoint over HTTPS.

    Returns:
        bool: True when the endpoint answers with HTTP 200 within 10 seconds;
        False for any other status code or if the request raises.
    """
    status_url = "https://openrouter.ai/api/v1/status"
    try:
        reply = requests.get(status_url, timeout=10)
        reachable = reply.status_code == 200
    except Exception as exc:
        # Any failure (DNS, TLS, timeout, ...) is reported as "not reachable".
        print(f"Connection test failed: {exc}")
        return False
    return reachable


def initialize_client() -> OpenAI:
    """Build an OpenAI-compatible client that talks to OpenRouter.

    The API key is read from the environment first; DNS resolution and basic
    connectivity are then checked before the client is constructed.

    Returns:
        OpenAI: Client pointed at ``https://openrouter.ai/api/v1`` with a
        120-second timeout and an httpx transport configured with 3 retries.

    Raises:
        ValueError: If OPENROUTER_API_KEY is unset or empty.
        ConnectionError: If the DNS check or the connection check fails.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable is not set.")

    # Generous timeout (seconds): requests carry images and may be slow.
    request_timeout = 120

    # Fail fast on network problems before creating the client.
    if not verify_dns():
        raise ConnectionError("DNS verification failed. Please check your network settings.")
    if not verify_connection():
        raise ConnectionError(
            "Cannot connect to OpenRouter. Please check your internet connection."
        )

    # The same timeout is applied both to the SDK and to the underlying
    # httpx client; the transport is created with retries=3.
    transport = httpx.HTTPTransport(retries=3)
    http_client = httpx.Client(timeout=request_timeout, transport=transport)
    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        timeout=request_timeout,
        http_client=http_client,
    )


def _parse_required_figures(raw_figures: Any) -> List[str]:
    """Normalize a question's ``figures`` field into ``"Figure ..."`` labels.

    Accepts a list, a JSON-encoded string, a plain string, or any other
    scalar. A JSON string that decodes to a non-list (e.g. ``"3"`` or
    ``'"Figure 1"'``) is treated as a single entry instead of being iterated,
    and every entry is converted to ``str`` before the prefix check.
    """
    parsed = raw_figures
    if isinstance(raw_figures, str):
        try:
            parsed = json.loads(raw_figures)
        except json.JSONDecodeError:
            # Not JSON: the string itself is the figure reference.
            parsed = raw_figures

    if parsed is None:
        return []

    entries = parsed if isinstance(parsed, list) else [parsed]
    labels = []
    for entry in entries:
        text = str(entry)
        labels.append(text if text.startswith("Figure ") else f"Figure {text}")
    return labels


def _collect_image_urls(required_figures: List[str], case_details: Dict[str, Any]) -> List[str]:
    """Return the subfigure URLs in ``case_details`` matching the given labels.

    For each label, all digits form the figure number and the letters of the
    last whitespace-separated token (if any) select a specific subfigure;
    without letters, every subfigure of the matching figure is used.
    """
    image_urls = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        for case_figure in case_details.get("figures", []):
            if case_figure["number"] != f"Figure {base_figure_num}":
                continue

            subfigures = case_figure.get("subfigures", [])
            if figure_letter:
                wanted = figure_letter.lower()
                subfigures = [
                    subfig
                    for subfig in subfigures
                    if subfig.get("number", "").lower().endswith(wanted)
                    or subfig.get("label", "").lower() == wanted
                ]

            image_urls.extend(subfig["url"] for subfig in subfigures if "url" in subfig)
    return image_urls


def _build_log_input(question_data: Dict[str, Any], image_urls: List[str]) -> Dict[str, Any]:
    """Build the ``input`` section shared by every log entry.

    Uses ``.get`` so that assembling a log record never raises a KeyError
    that would mask the outcome actually being logged.
    """
    return {
        "question_data": {
            "question": question_data.get("question"),
            "explanation": question_data.get("explanation"),
            "metadata": question_data.get("metadata", {}),
            "figures": question_data.get("figures"),
        },
        "image_urls": image_urls,
    }


# NOTE(review): the OpenAI SDK may wrap transport failures in its own
# exception types, in which case the exceptions listed here would not trigger
# a retry -- confirm against the SDK's error-handling documentation.
@backoff.on_exception(
    backoff.expo,
    (ConnectionError, TimeoutError, socket.gaierror, httpx.ConnectError),
    max_tries=5,
    max_time=300,  # Maximum total time to try in seconds
)
def create_multimodal_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    client: OpenAI,
) -> Optional[Any]:
    """Create and send a multimodal request to the model.

    The question text and the images of the figures it references are sent
    to ``MODEL_NAME``. The model's reply is reduced to a single answer letter
    via ``validate_answer`` and written back into the response object. Every
    outcome (skipped, answered, failed) is logged as one JSON record.

    Args:
        question_data: Dictionary containing question details; reads the
            ``question``, ``figures``, ``answer`` and optionally
            ``explanation`` / ``metadata`` keys.
        case_details: Dictionary containing case information; its ``figures``
            list supplies the subfigure URLs.
        case_id: ID of the medical case
        question_id: ID of the specific question
        client: OpenAI client instance

    Returns:
        Optional[Any]: The model response, with ``choices[0].message.content``
        replaced by the cleaned answer (possibly None if no letter could be
        extracted); None if the question was skipped because no images were
        found.

    Raises:
        ConnectionError: If connection fails
        TimeoutError: If request times out
        Exception: For other errors (logged, then re-raised)
    """

    system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
Rules:
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
2. Do not add periods, explanations, or any other text
3. Do not use markdown or formatting
4. Do not restate the question
5. Do not explain your reasoning

Examples of valid responses:
A
B
C

Examples of invalid responses:
"A."
"Answer: B"
"C) This shows..."
"The answer is D"
"""

    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    # Parse required figures; a malformed or missing field means "no figures"
    # so the question is skipped below instead of crashing the run.
    try:
        required_figures = _parse_required_figures(question_data["figures"])
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    # Message content: the text prompt followed by one entry per image.
    image_urls = _collect_image_urls(required_figures, case_details)
    content = [{"type": "text", "text": prompt}]
    content.extend({"type": "image_url", "image_url": {"url": url}} for url in image_urls)

    if not image_urls:  # Only the text prompt exists
        print(f"No images found for case {case_id}, question {question_id}")
        # Log the skipped question
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "status": "skipped",
            "reason": "no_images",
            "input": _build_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        return None

    try:
        start_time = time.time()

        response = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=temperature,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
        )
        duration = time.time() - start_time

        # Get raw response
        raw_answer = response.choices[0].message.content

        # Validate and clean
        clean_answer = validate_answer(raw_answer)

        if not clean_answer:
            print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
            print(f"Raw response: {raw_answer}")

        # Update response object with cleaned answer
        response.choices[0].message.content = clean_answer

        # Usage may be absent on some responses; log None rather than fail
        # after the answer has already been obtained.
        usage = response.usage

        # Log response
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "temperature": temperature,
            "duration": round(duration, 2),
            "usage": {
                "prompt_tokens": getattr(usage, "prompt_tokens", None),
                "completion_tokens": getattr(usage, "completion_tokens", None),
                "total_tokens": getattr(usage, "total_tokens", None),
            },
            "model_answer": response.choices[0].message.content,
            "correct_answer": question_data["answer"],
            "input": _build_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        return response

    except ConnectionError as e:
        print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
        print("Retrying after a longer delay...")
        time.sleep(30)  # Add a longer delay before retry
        raise
    except TimeoutError as e:
        print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
        print("Retrying with increased timeout...")
        raise
    except Exception as e:
        # Log failed requests too
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "temperature": temperature,
            "status": "error",
            "error": str(e),
            "input": _build_log_input(question_data, image_urls),
        }
        logging.info(json.dumps(log_entry))
        raise


def extract_answer(response_text: str) -> Optional[str]:
    """Extract single letter answer from model response.

    The text is upper-cased and stripped of periods, then matched against a
    list of patterns in priority order. Every pattern requires the letter to
    stand on its own, so words such as "OPTIONAL", "BECAUSE" or "IMAGE)" no
    longer produce a spurious answer.

    Args:
        response_text: Raw text response from model (may be empty or None)

    Returns:
        Optional[str]: Single letter answer (A-F) if found, None otherwise
    """
    if not response_text:
        return None

    # Convert to uppercase and remove periods
    text = response_text.upper().replace(".", "")

    # Most specific pattern first; \b keeps the letter from being the start
    # or end of a longer word.
    patterns = (
        r"ANSWER:\s*([A-F])\b",  # Matches "ANSWER: X"
        r"OPTION\s*([A-F])\b",  # Matches "OPTION X"
        r"\b([A-F])\)",  # Matches "X)"
        r"\b([A-F])\b",  # Matches single letter
    )

    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1)

    return None


def validate_answer(response_text: str) -> Optional[str]:
    """Enforce strict single-letter response format.

    A response that is exactly one letter A-F (ignoring surrounding
    whitespace and case) is returned as-is. Otherwise the letter is
    recovered from the text, preferring an explicit "answer is X" /
    "answer: X" marker and falling back to the first letter A-F that stands
    alone. Letters inside words are never used, so "The answer is D" yields
    "D" (not the "E" of "THE") and free text without a standalone letter
    yields None.

    Args:
        response_text: Raw text response from model (may be empty or None)

    Returns:
        Optional[str]: Valid single letter answer if found, None otherwise
    """
    if not response_text:
        return None

    # Strip surrounding whitespace and convert to uppercase
    cleaned = response_text.strip().upper()

    # Check if it's exactly one valid letter
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned

    # Prefer an explicit marker so a leading article ("A mass ...") does not
    # win over the stated answer.
    marked = re.search(r"\bANSWER(?:\s+IS)?\s*:?\s*\(?([A-F])\b", cleaned)
    if marked:
        return marked.group(1)

    # Otherwise take the first standalone letter; apostrophes count as part
    # of a word so contractions like "I'D" are not read as answer "D".
    standalone = re.search(r"(?<![\w'’])([A-F])(?![\w'’])", cleaned)
    return standalone.group(1) if standalone else None


def load_benchmark_questions(case_id: str) -> List[str]:
    """Locate the question files belonging to one case.

    Args:
        case_id: ID of the medical case

    Returns:
        List[str]: Paths matching
        ``../benchmark/questions/<case_id>/<case_id>_*.json`` (relative to the
        current working directory), in the order returned by ``glob``.
    """
    pattern = "/".join(["../benchmark/questions", case_id, f"{case_id}_*.json"])
    return glob.glob(pattern)


def count_total_questions() -> Tuple[int, int]:
    """Count total number of cases and questions.

    A case is a non-hidden subdirectory of ``../benchmark/questions``; its
    questions are the ``*.json`` files directly inside it. Both totals are
    derived from the same directory listing so they always describe the same
    set of cases, and stray files next to the case directories are ignored.

    Returns:
        Tuple[int, int]: (total_cases, total_questions); ``(0, 0)`` if the
        questions directory does not exist.
    """
    questions_root = "../benchmark/questions"
    if not os.path.isdir(questions_root):
        return 0, 0

    case_ids = [
        entry
        for entry in os.listdir(questions_root)
        if not entry.startswith(".") and os.path.isdir(os.path.join(questions_root, entry))
    ]
    total_questions = sum(
        len(glob.glob(f"{questions_root}/{case_id}/*.json")) for case_id in case_ids
    )
    return len(case_ids), total_questions


def main():
    """Run the benchmark over every case that has question files.

    Loads the case metadata, creates the API client, then sends each
    question of each case to the model, printing progress per answered
    question and a summary at the end. Questions for which the request
    returns None are counted as skipped.
    """
    with open("../data/eurorad_metadata.json", "r") as metadata_file:
        cases = json.load(metadata_file)

    client = initialize_client()
    total_cases, total_questions = count_total_questions()

    # Running tallies for the final summary.
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")

    for case_id, case_details in cases.items():
        question_paths = load_benchmark_questions(case_id)
        if question_paths:
            # Only cases that actually have question files are counted.
            cases_processed += 1

        for question_path in question_paths:
            with open(question_path, "r") as question_file:
                question_data = json.load(question_file)
            # File name up to the first dot serves as the question ID.
            question_id = os.path.basename(question_path).partition(".")[0]

            questions_processed += 1
            response = create_multimodal_request(
                question_data, case_details, case_id, question_id, client
            )

            if response is not None:
                print(
                    f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
                )
                print(f"Case ID: {case_id}")
                print(f"Question ID: {question_id}")
                print(f"Model Answer: {response.choices[0].message.content}")
                print(f"Correct Answer: {question_data['answer']}\n")
            else:
                skipped_questions += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")

    print("\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()