# ACL-2025 / LLaMA_90B_infer_batch.py
# Uploaded by Wendy with huggingface_hub ("Upload LLaMA_90B_infer_batch.py with huggingface_hub").
# Commit: 88ee19f (verified).
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
# Load Llama-3.2-90B-Vision-Instruct and its processor from a local checkpoint.
# Weights are loaded in bfloat16; device_map="auto" lets transformers/accelerate
# decide where the weights are placed across the available devices.
model_id = "Model/Llama-3.2-90B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
import json
import pprint
from tqdm import tqdm
import json
import argparse
def read_json(file_path):
    """Load and return the JSON document stored (UTF-8 encoded) at *file_path*."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)
def write_json(file_path, data):
    """Write *data* to *file_path* as UTF-8 JSON (indent=4, non-ASCII kept as-is).

    The document is first written to ``file_path + '.tmp'`` and then moved
    over *file_path* with ``os.replace``. This function is used for periodic
    checkpoints, so an interruption or a serialization error mid-write must
    not truncate the previous, valid file.

    Raises:
        TypeError / ValueError: propagated from ``json.dump`` for data that
        cannot be serialized (the existing *file_path* is left untouched).
        OSError: propagated from file creation or ``os.replace``.
    """
    import os

    tmp_path = file_path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        # Same directory => same filesystem, so the rename replaces the target
        # in one step (atomic on POSIX).
        os.replace(tmp_path, file_path)
    except BaseException:
        # Do not leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
# Annotation file processed by this run. Inputs used by earlier runs are
# listed below for reference; point `data_path` at one of them to switch.
# data_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/MiniCPM-V/all_blip_train_llava_coco_layout_caption_s1s3.json"
# data_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_train_llava_coco_layout_all_test.json"
# data_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_train_llava_coco_layout_all_train.json"
# data_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Train_ALL_BBox_V0_Half.json"
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_0.json'
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_1.json'
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_2.json'
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_3.json'
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_4.json'
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/AITM_Test_ALL_V0_down.json'
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_test_llava_coco_layout_all_bbox_v3.json'
# data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/AITM_Train_ALL_BBox_V0.json'
data_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/all_blip_test_llava_coco_layout_AITM_0.json'
data = read_json(data_path)
# Command-line configuration.
#   --index       shard number: this process handles data[index*gap:(index+1)*gap]
#                 and writes to a save file tagged with the index.
#   --gap         items per shard; defaults to len(data), i.e. the whole file is
#                 one shard and only --index 0 has work to do (original behavior).
#   --batch-size  items per generate() call; defaults to 10 (original behavior).
parser = argparse.ArgumentParser(description="Process a dataset with specific index range.")
parser.add_argument("--index", type=int, required=True,
                    help="Shard index: processes data[index*gap:(index+1)*gap] and tags the output file name.")
parser.add_argument("--gap", type=int, default=None,
                    help="Items per shard (default: the whole dataset, so only --index 0 has work).")
parser.add_argument("--batch-size", type=int, default=10,
                    help="Number of items sent to the model per generate() call (default: 10).")
args = parser.parse_args()
if args.index < 0:
    parser.error("--index must be non-negative")
if args.gap is not None and args.gap <= 0:
    parser.error("--gap must be a positive integer")
if args.batch_size <= 0:
    parser.error("--batch-size must be a positive integer")

index = args.index
gap = args.gap if args.gap is not None else len(data)
save_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Test_ALL_BBox_New_CapCoT_' + str(index) + '.json'
# Earlier configurations (for reference):
# save_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Train_ALL_BBox_V0_Cap_' + str(index) + '.json'
# gap = 500
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_' + str(index) + '.json'
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_standby' + str(index) + '.json'
# save_path = '/home/ma-user/work/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_WLCB' + str(index) + '.json'
# begin = (index+1)*gap - 2500
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_WLCB' + str(index) + '.json'

begin = index * gap
# Clamp so the batch range never walks past the data.
end = min((index + 1) * gap, len(data))
counter = 0  # number of processed batches; drives periodic checkpoints
batch_size = args.batch_size
if begin >= len(data):
    print(f"Warning: shard {index} starts at item {begin}, but the dataset has only {len(data)} items; nothing to process.")
# ---------------------------------------------------------------------------
# Batched annotation loop over data[begin:end], `batch_size` items at a time.
# Each batch gets two generation passes on the item's screenshot:
#   task 1 -> item['90B_caption']: a brief description of the screenshot;
#   task 2 -> item['90B_CoT']:     reasoning about which element to interact
#                                  with, given the goal and the action target.
# Items that already have a field are skipped for that pass only, so a
# partially annotated file can be resumed. `data` is checkpointed to
# `save_path` every SAVE_EVERY batches and once more after the loop.
# (An earlier, now-disabled variant also re-captioned items whose caption
# contained refusals such as 'no image' / "don't see ".)
# ---------------------------------------------------------------------------

IMAGE_ROOT = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/LLaVA-AiTW/'
SAVE_EVERY = 100  # checkpoint interval, in batches
MAX_NEW_TOKENS = 512

# Task 1 prompt. A longer earlier variant was:
# " Describe the image in detail, including the main objects, their colors, positions, and relationships, as well as the background and any visible text. Highlight any actions, interactions, or notable details in a clear and concise manner. "
CAPTION_PROMPT = " Provide a brief description of the image, including the main elements, their positions and relationships, as well as the background and any visible text, expressed clearly and concisely. "


def _load_image(relative_path):
    """Open IMAGE_ROOT + relative_path and return an in-memory copy, closing the file."""
    with Image.open(IMAGE_ROOT + relative_path) as img:
        # copy() loads the pixels, so the copy stays usable after the
        # with-block closes the underlying file.
        return img.copy()


def _caption_prompt(item):
    """Task 1 prompt; the same text for every item."""
    return CAPTION_PROMPT


def _cot_prompt(item):
    """Task 2 prompt, built from item['ori_question'] and item['action_target']."""
    # Text following the first 'Goal:' marker; IndexError if the marker is absent.
    goal = item['ori_question'].split('Goal:')[1]
    return (" The goal is : " + goal + " The target element is : " + item['action_target']
            + " ###### Then analyze what's in the image and reason about that the target element of the image you should interact with in this step. ")


def _generate(images, prompts):
    """Run one batched generate() call; return one decoded string per prompt, in order."""
    chat_texts = [
        processor.apply_chat_template(
            [{"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ]}],
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]
    inputs = processor(
        images,
        chat_texts,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    # NOTE(review): batched generation with a decoder LM usually requires left
    # padding; confirm processor.tokenizer.padding_side for this checkpoint.
    output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Whole output rows are decoded exactly as before (presumably prompt +
    # completion, special/pad tokens included) to keep the stored format.
    return [processor.decode(seq) for seq in output]


def _run_task(batch, output_key, build_prompt):
    """Fill item[output_key] for items in `batch` that lack it; return how many were filled."""
    pending = [item for item in batch if output_key not in item]
    if not pending:
        return 0
    images = [_load_image(item['image']) for item in pending]
    prompts = [build_prompt(item) for item in pending]
    # Pair results with `pending`, not `batch`: skipped items must neither be
    # overwritten nor shift the results onto the wrong entries.
    for item, text in zip(pending, _generate(images, prompts)):
        item[output_key] = text
    return len(pending)


unsaved = 0  # items annotated since the last checkpoint
for batch_idx in tqdm(range(begin, end, batch_size)):
    # A slice of `data` holds the same dict objects, so updates land in `data`.
    batch = data[batch_idx:batch_idx + batch_size]
    unsaved += _run_task(batch, '90B_caption', _caption_prompt)  # task 1
    unsaved += _run_task(batch, '90B_CoT', _cot_prompt)          # task 2
    counter += 1
    # Checkpoint every SAVE_EVERY batches.
    if counter % SAVE_EVERY == 0:
        print(f"Saving data at iteration {counter}")
        write_json(save_path, data)
        unsaved = 0

# Persist whatever was produced after the last periodic checkpoint.
if unsaved:
    print(f"Saving data at iteration {counter}")
    write_json(save_path, data)
# messages = [
# {"role": "user", "content": [
# {"type": "image"},
# {"type": "text", "text": "Detailed description of the content in the image and the location of the elements that can be interacted with. The position information can be the scale of the center point of the interactable element in the image with the upper left corner as the origin (0, 0). The scale of the image is (width, height). The unit of the position information is the percentage of the width and height of the image. For example, if the image is 800*400, the position of the upper left corner is (0, 0), and the position of the lower right corner is (100, 100). The position of the center of the image is (50, 50). Such as, the location of Search bar is at (20,60) . "}
# ]}
# ]
# input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
# inputs = processor(
# image,
# input_text,
# add_special_tokens=False,
# return_tensors="pt",
# ).to(model.device)
# output = model.generate(**inputs, max_new_tokens=512)
# print(processor.decode(output[0]))