TianheWu committed
Commit c10e6b1 · verified · 1 Parent(s): 4820de8

Update README.md

Files changed (1): README.md (+509 -29)

README.md CHANGED
The previous `## Quick Start` section is removed: it held a single example whose `score_image(model_path, image_path)` reloaded the model on every call, generated with `max_new_tokens=256`, and parsed the score without error handling. The expanded version below replaces it.
Paper link: [arXiv](https://arxiv.org/abs/2505.14460)

## Quick Start

### Non-Thinking Inference
When you run VisualQuality-R1 as a reward/evaluation model, you can use **non-thinking** mode to reduce inference time: the model skips the reasoning trace and emits only a short final answer when prompted as follows:
```
PROMPT = (
    "You are doing the image quality assessment task. Here is the question: "
    "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
    "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
)

QUESTION_TEMPLATE = "{Question} Please only output the final answer with only one score in <answer> </answer> tags."
```
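With this template the completion is just the tagged score, e.g. `<answer>3.85</answer>` (value illustrative). A minimal sketch of pulling the number out, matching the parsing used in the full scripts below:

```python
import re

completion = "<answer>3.85</answer>"  # illustrative non-thinking output
matches = re.findall(r'<answer>(.*?)</answer>', completion, re.DOTALL)
score = float(re.search(r'\d+(\.\d+)?', matches[-1]).group())  # 3.85
```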

For single image quality rating, the code is:

<details>
<summary>Example Code (VisualQuality-R1: Single Image Quality Rating with non-thinking mode)</summary>

```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

import torch
import random
import re


def score_image(image_path, model, processor):
    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
    )

    QUESTION_TEMPLATE = "{Question} Please only output the final answer with only one score in <answer> </answer> tags."
    message = [
        {
            "role": "user",
            "content": [
                {'type': 'image', 'image': image_path},
                {"type": "text", "text": QUESTION_TEMPLATE.format(Question=PROMPT)}
            ],
        }
    ]

    batch_messages = [message]

    # Preparation for inference
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True, add_vision_id=True) for msg in batch_messages]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    # Inference: the non-thinking answer is short, so generation stops at EOS
    # well before the token cap
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=2048, do_sample=True, top_k=50, top_p=1)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    reasoning = None  # non-thinking mode produces no reasoning trace

    try:
        model_output_matches = re.findall(r'<answer>(.*?)</answer>', batch_output_text[0], re.DOTALL)
        model_answer = model_output_matches[-1].strip() if model_output_matches else batch_output_text[0].strip()
        score = float(re.search(r'\d+(\.\d+)?', model_answer).group())
    except Exception:
        print(f"================= Meet error with {image_path}, please generate again. =================")
        score = random.randint(1, 5)

    return reasoning, score


random.seed(1)
MODEL_PATH = ""
device = torch.device("cuda:5") if torch.cuda.is_available() else torch.device("cpu")
image_path = ""

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
processor.tokenizer.padding_side = "left"

reasoning, score = score_image(
    image_path, model, processor
)

print(score)
```
</details>
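
When the score feeds an RL-style training loop or a preference ranker as a reward, a common convention (not prescribed by the model) is to map the 1-5 rating onto [0, 1]. A minimal sketch, assuming the `score_image` function and variables from the script above:

```python
def score_to_reward(score: float) -> float:
    # Linearly map a 1-5 quality rating onto [0, 1].
    return (score - 1.0) / 4.0

_, score = score_image(image_path, model, processor)
reward = score_to_reward(score)
```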

<details>
<summary>Example Code (VisualQuality-R1: Batch Image Quality Rating with non-thinking mode)</summary>

```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from tqdm import tqdm

import torch
import random
import re
import os


def get_image_paths(folder_path):
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
    image_paths = []

    for root, dirs, files in os.walk(folder_path):
        for file in files:
            _, ext = os.path.splitext(file)
            if ext.lower() in image_extensions:
                image_paths.append(os.path.join(root, file))

    return image_paths

def score_batch_image(image_paths, model, processor):
    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
    )

    QUESTION_TEMPLATE = "{Question} Please only output the final answer with only one score in <answer> </answer> tags."

    messages = []
    for img_path in image_paths:
        message = [
            {
                "role": "user",
                "content": [
                    {'type': 'image', 'image': img_path},
                    {"type": "text", "text": QUESTION_TEMPLATE.format(Question=PROMPT)}
                ],
            }
        ]
        messages.append(message)

    BSZ = 32
    all_outputs = []  # List to store all answers
    for i in tqdm(range(0, len(messages), BSZ)):
        batch_messages = messages[i:i + BSZ]

        # Preparation for inference
        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True, add_vision_id=True) for msg in batch_messages]

        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Inference: generation of the output
        generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=512, do_sample=True, top_k=50, top_p=1)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        all_outputs.extend(batch_output_text)

    path_score_dict = {}
    for img_path, model_output in zip(image_paths, all_outputs):
        try:
            model_output_matches = re.findall(r'<answer>(.*?)</answer>', model_output, re.DOTALL)
            model_answer = model_output_matches[-1].strip() if model_output_matches else model_output.strip()
            score = float(re.search(r'\d+(\.\d+)?', model_answer).group())
        except Exception:
            print(f"Meet error with {img_path}, please generate again.")
            score = random.randint(1, 5)

        path_score_dict[img_path] = score

    return path_score_dict


random.seed(1)
MODEL_PATH = ""
device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
processor.tokenizer.padding_side = "left"

image_root = ""
image_paths = get_image_paths(image_root)  # It should be a list

path_score_dict = score_batch_image(
    image_paths, model, processor
)

file_name = "output.txt"
with open(file_name, "w") as file:
    for key, value in path_score_dict.items():
        file.write(f"{key} {value}\n")

print("Done!")
```
</details>
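
To evaluate the predicted scores against human ratings, IQA work usually reports SRCC and PLCC. A minimal sketch with `scipy`; the `mos_dict` labels here are hypothetical placeholders for your own mean opinion scores:

```python
from scipy.stats import spearmanr, pearsonr

# Hypothetical ground-truth mean opinion scores, keyed by image path.
mos_dict = {p: 3.0 for p in image_paths}  # replace with your MOS labels

predicted = [path_score_dict[p] for p in image_paths]
labels = [mos_dict[p] for p in image_paths]

srcc, _ = spearmanr(predicted, labels)
plcc, _ = pearsonr(predicted, labels)
print(f"SRCC: {srcc:.4f}, PLCC: {plcc:.4f}")
```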

### Thinking Inference
In thinking mode, the model first writes out its reasoning in `<think> </think>` tags and then gives the final score in `<answer> </answer>` tags. This costs more tokens but exposes the rationale behind each rating.

<details>
<summary>Example Code (VisualQuality-R1: Single Image Quality Rating with thinking)</summary>

```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

import torch
import random
import re


def score_image(image_path, model, processor):
    # PROMPT already embeds the thinking-mode instruction.
    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality. "
        "First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."
    )
    # Alternative, free-form prompt:
    # PROMPT = "Please describe the quality of this image."

    message = [
        {
            "role": "user",
            "content": [
                {'type': 'image', 'image': image_path},
                {"type": "text", "text": PROMPT}
            ],
        }
    ]

    batch_messages = [message]

    # Preparation for inference
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True, add_vision_id=True) for msg in batch_messages]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=2048, do_sample=True, top_k=50, top_p=1)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    reasoning = re.findall(r'<think>(.*?)</think>', batch_output_text[0], re.DOTALL)
    reasoning = reasoning[-1].strip() if reasoning else ""

    try:
        model_output_matches = re.findall(r'<answer>(.*?)</answer>', batch_output_text[0], re.DOTALL)
        model_answer = model_output_matches[-1].strip() if model_output_matches else batch_output_text[0].strip()
        score = float(re.search(r'\d+(\.\d+)?', model_answer).group())
    except Exception:
        print(f"================= Meet error with {image_path}, please generate again. =================")
        score = random.randint(1, 5)

    return reasoning, score


random.seed(1)
MODEL_PATH = ""
device = torch.device("cuda:5") if torch.cuda.is_available() else torch.device("cpu")
image_path = ""

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
processor.tokenizer.padding_side = "left"

reasoning, score = score_image(
    image_path, model, processor
)

print(reasoning)
print(score)
```
</details>
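
Since decoding uses `do_sample=True`, each run can return a slightly different score. If you want a more stable estimate, one option (an illustration, not part of the original script) is to sample several ratings and average them:

```python
import statistics

def score_image_stable(image_path, model, processor, n_samples=5):
    # Average several stochastic ratings from the score_image function above.
    scores = [score_image(image_path, model, processor)[1] for _ in range(n_samples)]
    return statistics.mean(scores)
```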


<details>
<summary>Example Code (VisualQuality-R1: Batch Image Quality Rating with thinking)</summary>

```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from tqdm import tqdm

import torch
import random
import re
import os


def get_image_paths(folder_path):
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
    image_paths = []

    for root, dirs, files in os.walk(folder_path):
        for file in files:
            _, ext = os.path.splitext(file)
            if ext.lower() in image_extensions:
                image_paths.append(os.path.join(root, file))

    return image_paths

def score_batch_image(image_paths, model, processor):
    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
    )

    QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."

    messages = []
    for img_path in image_paths:
        message = [
            {
                "role": "user",
                "content": [
                    {'type': 'image', 'image': img_path},
                    {"type": "text", "text": QUESTION_TEMPLATE.format(Question=PROMPT)}
                ],
            }
        ]
        messages.append(message)

    BSZ = 32
    all_outputs = []  # List to store all answers
    for i in tqdm(range(0, len(messages), BSZ)):
        batch_messages = messages[i:i + BSZ]

        # Preparation for inference
        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True, add_vision_id=True) for msg in batch_messages]

        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Inference: generation of the output
        generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=512, do_sample=True, top_k=50, top_p=1)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        all_outputs.extend(batch_output_text)

    path_score_dict = {}
    for img_path, model_output in zip(image_paths, all_outputs):
        # The reasoning trace is extracted here, but only the score is saved below.
        reasoning = re.findall(r'<think>(.*?)</think>', model_output, re.DOTALL)
        reasoning = reasoning[-1].strip() if reasoning else ""

        try:
            model_output_matches = re.findall(r'<answer>(.*?)</answer>', model_output, re.DOTALL)
            model_answer = model_output_matches[-1].strip() if model_output_matches else model_output.strip()
            score = float(re.search(r'\d+(\.\d+)?', model_answer).group())
        except Exception:
            print(f"Meet error with {img_path}, please generate again.")
            score = random.randint(1, 5)

        path_score_dict[img_path] = score

    return path_score_dict


random.seed(1)
MODEL_PATH = ""
device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
processor.tokenizer.padding_side = "left"

image_root = ""
image_paths = get_image_paths(image_root)  # It should be a list

path_score_dict = score_batch_image(
    image_paths, model, processor
)

file_name = "output.txt"
with open(file_name, "w") as file:
    for key, value in path_score_dict.items():
        file.write(f"{key} {value}\n")

print("Done!")
```
</details>
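
The batch script extracts each reasoning trace but writes only the scores to disk. If you also want the traces, one illustrative option is to collect them into a second dict inside the same loop (a hypothetical `path_reasoning_dict`, filled the same way `path_score_dict` is) and save both as JSON lines:

```python
import json

def save_results(path_score_dict, path_reasoning_dict, file_name="output.jsonl"):
    # Illustrative helper: one JSON object per image, keeping score and reasoning together.
    with open(file_name, "w") as f:
        for img_path, score in path_score_dict.items():
            record = {
                "image": img_path,
                "score": score,
                "reasoning": path_reasoning_dict.get(img_path, ""),
            }
            f.write(json.dumps(record) + "\n")
```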


## 🚀 Updated: VisualQuality-R1 high-efficiency inference script with vLLM
For large image sets, vLLM schedules and batches requests internally, which is considerably faster than looping over `model.generate` in transformers.

<details>
<summary>Example Code (VisualQuality-R1: Batch Image Quality Rating with thinking, using vLLM)</summary>

```python
# Please install vLLM first: https://docs.vllm.ai/en/stable/getting_started/installation/gpu.html

from transformers import Qwen2_5_VLProcessor, AutoProcessor
from vllm import LLM, RequestOutput, SamplingParams
from qwen_vl_utils import process_vision_info

import random
import re
import os

IMAGE_PATH = "./images"
MODEL_PATH = "TianheWu/VisualQuality-R1-7B"

def get_image_paths(folder_path):
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
    image_paths = []

    for root, dirs, files in os.walk(folder_path):
        for file in files:
            _, ext = os.path.splitext(file)
            if ext.lower() in image_extensions:
                image_paths.append(os.path.join(root, file))

    return image_paths

def score_batch_image(image_paths, model: LLM, processor: Qwen2_5_VLProcessor):
    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
    )

    QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."

    messages = []
    for img_path in image_paths:
        message = [
            {
                "role": "user",
                "content": [
                    {'type': 'image', 'image': img_path},
                    {"type": "text", "text": QUESTION_TEMPLATE.format(Question=PROMPT)}
                ],
            }
        ]
        messages.append(message)

    # Preparation for inference
    print("preprocessing ...")
    texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True, add_vision_id=True) for msg in messages]
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = [{
        "prompt": texts[i],
        "multi_modal_data": {
            "image": image_inputs[i]
        },
    } for i in range(len(messages))]

    # vLLM batches and schedules all requests internally
    output: list[RequestOutput] = model.generate(
        inputs,
        sampling_params=SamplingParams(
            max_tokens=512,
            temperature=0.1,
            top_k=50,
            top_p=1.0,
            stop_token_ids=[processor.tokenizer.eos_token_id],
        ),
    )

    all_outputs = [o.outputs[0].text for o in output]

    path_score_dict = {}
    for img_path, model_output in zip(image_paths, all_outputs):
        print(f"{model_output = }")  # debug print of the raw completion
        try:
            model_output_matches = re.findall(r'<answer>(.*?)</answer>', model_output, re.DOTALL)
            model_answer = model_output_matches[-1].strip() if model_output_matches else model_output.strip()
            score = float(re.search(r'\d+(\.\d+)?', model_answer).group())
        except Exception:
            print(f"Meet error with {img_path}, please generate again.")
            score = random.randint(1, 5)

        path_score_dict[img_path] = score

    return path_score_dict


random.seed(1)
model = LLM(
    model=MODEL_PATH,
    tensor_parallel_size=1,
    trust_remote_code=True,
    seed=1,
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)
processor.tokenizer.padding_side = "left"

image_paths = get_image_paths(IMAGE_PATH)  # It should be a list

path_score_dict = score_batch_image(
    image_paths, model, processor
)

file_name = "output.txt"
with open(file_name, "w") as file:
    for key, value in path_score_dict.items():
        file.write(f"{key} {value}\n")

print("Done!")
```
</details>
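
All of the batch scripts write one `path score` pair per line to `output.txt`. A minimal sketch of loading the scores back (using `rsplit` so paths containing spaces still parse):

```python
scores = {}
with open("output.txt") as f:
    for line in f:
        path, value = line.rstrip("\n").rsplit(" ", 1)
        scores[path] = float(value)
```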