import os
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import json
from tqdm import tqdm
def read_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data

def write_json(file_path, data):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)
# default: Load the model on the available device(s)
print(torch.cuda.device_count())
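# Local checkpoint directory for the Qwen2.5-VL weights.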
model_path = "/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt"
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     model_path, torch_dtype="auto", device_map="auto"
# )
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# default processor
processor = AutoProcessor.from_pretrained(model_path)
print(model.device)
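# Each record supplies 'videos', 'images', and a two-turn 'messages' list
# (user question, ground-truth answer).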
data = read_json('/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/LLaMA-Factory/data/Percption.json')
save_data = []
correct_num = 0  # placeholder accuracy counter; never incremented below
begin = 0
end = len(data)
batch_size = 1  # apply_chat_template below treats data_list as a single conversation, so keep this at 1
for batch_idx in tqdm(range(begin, end, batch_size)):
    batch = data[batch_idx:batch_idx + batch_size]
    image_list = []
    input_text_list = []
    data_list = []
    save_list = []
    sd_ans = []
    for idx, i in enumerate(batch):
        # Record that will be saved to disk; 'answer' holds the ground truth,
        # 'result' is filled with the model output after generation.
        save_ = {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": "file:///path/to/video1.mp4",
                    "max_pixels": 360 * 420,
                    "fps": 1.0,
                },
                {"type": "image", "image": "file:///path/to/image2.jpg"},
                {"type": "text", "text": "Describe this video."},
            ],
            "answer": "None",
            "result": "None",
        }
        # Message template for the processor; the placeholder paths and prompt
        # are overwritten with the sample's actual values below.
        messages = {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": "file:///path/to/video1.mp4",
                    "max_pixels": 360 * 420,
                    "fps": 1.0,
                },
                {"type": "image", "image": "file:///path/to/image2.jpg"},
                {"type": "text", "text": "Describe this video."},
            ],
        }
        video_path = i['videos']
        image_path = i['images']
        question = i['messages'][0]['content']
        answer = i['messages'][1]['content']
        messages['content'][0]['video'] = video_path
        messages['content'][1]['image'] = image_path
        messages['content'][2]['text'] = question
        save_['content'][0]['video'] = video_path
        save_['content'][1]['image'] = image_path
        save_['content'][2]['text'] = question
        save_['answer'] = answer
        sd_ans.append(answer)
        data_list.append(messages)
        save_list.append(save_)
    # Build the chat prompt and extract image/video tensors for this batch.
    text = processor.apply_chat_template(data_list, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = process_vision_info(data_list, return_video_kwargs=True)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    inputs = inputs.to(model.device)
    # Inference: generate, then strip the prompt tokens so only the new
    # completion is decoded.
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # Attach each decoded answer to its saved record.
    for idx, x in enumerate(output_text):
        save_list[idx]['result'] = x
        save_data.append(save_list[idx])
    print("correct_num", correct_num)
    # Rewrite the output file after every batch so partial results survive
    # an interruption.
    write_json("infer_answer_ori.json", save_data)