|
import os |
|
|
|
import torch |
|
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor |
|
from qwen_vl_utils import process_vision_info |
|
import json |
|
from tqdm import tqdm |
|
import os |
|
|
|
def read_json(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is opened as UTF-8 text and parsed with ``json.load``; whatever
    object the document decodes to (dict, list, ...) is returned unchanged.
    Errors from ``open`` or ``json.load`` propagate to the caller.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
|
|
|
def write_json(file_path, data):
    """Serialize *data* as JSON to *file_path*, replacing any existing file.

    The output is UTF-8 text, indented by four spaces, and non-ASCII
    characters are written as-is rather than as ``\\uXXXX`` escapes.
    Errors from ``open`` or ``json.dump`` propagate to the caller.
    """
    with open(file_path, 'w', encoding='utf-8') as handle:
        json.dump(data, handle, ensure_ascii=False, indent=4)
|
|
|
|
|
# ---------------------------------------------------------------------------
# Video question-answering inference with Qwen2.5-VL.
#
# Loads the checkpoint, runs every sample of the evaluation JSON through
# `model.generate`, and writes one record per sample (prompt, reference
# answer, model output) to "infer_answer_percption.json".
# ---------------------------------------------------------------------------


def _user_message(video_path, question):
    """Return a fresh chat "user" turn holding one video and one text prompt.

    A new dict is built on every call so the message fed to the processor and
    the record that is saved to disk never share nested objects.
    """
    return {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": video_path,
                "max_pixels": 360 * 420,
                "fps": 1.0,
            },
            {"type": "text", "text": question},
        ],
    }


print(torch.cuda.device_count())

model_path = "/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt"

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(model_path)
# Batched generation must pad on the left so that every prompt ends right
# where generation starts; with a single-sample batch no padding is added,
# so this does not change the batch_size == 1 behaviour.
processor.tokenizer.padding_side = "left"
print(model.device)

data = read_json('/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/LLaMA-Factory/data/Percption_.json')
save_data = []
correct_num = 0
begin = 0
end = len(data)
batch_size = 1

for batch_idx in tqdm(range(begin, end, batch_size)):
    batch = data[batch_idx:batch_idx + batch_size]

    conversations = []  # one single-turn conversation per sample
    save_list = []      # output records, aligned with `conversations`

    for sample in batch:
        # NOTE(review): `videos` is forwarded unchanged as the video reference.
        # If the dataset stores a list of paths here, confirm that is what
        # process_vision_info expects (a list may be read as a frame list).
        video_path = sample['videos']
        question = sample['messages'][0]['content']
        answer = sample['messages'][1]['content']

        conversations.append([_user_message(video_path, question)])

        record = _user_message(video_path, question)
        record["answer"] = answer
        record["result"] = "None"  # placeholder until the model output is known
        save_list.append(record)

    # Render each conversation separately: passing all samples as one message
    # list would make the template treat them as turns of a single dialogue.
    texts = [
        processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
        for conv in conversations
    ]
    image_inputs, video_inputs, video_kwargs = process_vision_info(
        conversations, return_video_kwargs=True
    )
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    for record, prediction in zip(save_list, output_text):
        record['result'] = prediction
        # Strict criterion: a sample counts as correct only when the output
        # equals the reference answer after stripping surrounding whitespace.
        if prediction.strip() == str(record['answer']).strip():
            correct_num += 1
        save_data.append(record)

print("correct_num", correct_num)
write_json("infer_answer_percption.json", save_data)
|
|
|
|