import requests | |
import torch | |
from PIL import Image | |
from transformers import MllamaForConditionalGeneration, AutoProcessor | |
# Vision-language checkpoint used to annotate every sample in this script.
model_id = "Model/Llama-3.2-90B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
# Prompts are generated in batches of different lengths. For a decoder-only model the
# padding must be on the left, so that each row's prompt ends exactly where generation
# starts. With right padding, shorter prompts would be followed by pad tokens before the
# first generated token.
processor.tokenizer.padding_side = "left"
# NOTE: a hard-coded sample image (".../DataSet/10_1.png") used to be opened here.
# It was never used (both `temp` and `image` are reassigned before use), and it made
# the script fail whenever that file was missing, so it has been removed.
import json | |
import pprint | |
from tqdm import tqdm | |
import json | |
import argparse | |
def read_json(file_path):
    """Parse the UTF-8 encoded JSON file at *file_path* and return the decoded object."""
    with open(file_path, mode='r', encoding='utf-8') as json_file:
        return json.load(json_file)
def write_json(file_path, data):
    """Serialize *data* as pretty-printed UTF-8 JSON to *file_path*, atomically.

    The JSON is first written to a sibling ``<file_path>.tmp`` file and then moved over
    *file_path* with ``os.replace``. If serialization fails, or the process dies
    mid-write, the previous contents of *file_path* stay intact rather than being left
    truncated. This matters because the file is rewritten repeatedly as a checkpoint.

    Args:
        file_path: destination path (str or path-like).
        data: JSON-serializable object; non-ASCII text is written as-is (indent 4).

    Raises:
        TypeError: if *data* contains values ``json`` cannot serialize.
        OSError: on filesystem errors.
    """
    import os

    tmp_path = os.fspath(file_path) + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# ---------------------------------------------------------------------------
# Input dataset: keep exactly one `temp` assignment below active.
# ---------------------------------------------------------------------------
# data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/MiniCPM-V/all_blip_train_llava_coco_layout_caption_s1s3.json")
# data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_train_llava_coco_layout_all_test.json")
# data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_train_llava_coco_layout_all_train.json")
# data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Train_ALL_BBox_V0_Half.json")
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_0.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_1.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_2.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_3.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_4.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/AITM_Test_ALL_V0_down.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_test_llava_coco_layout_all_bbox_v3.json'
temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/all_blip_test_llava_coco_layout_AITM_0.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/AITM_Train_ALL_BBox_V0.json'
# A list of sample dicts: it is sliced below, and each element is indexed by key.
data = read_json(temp)

parser = argparse.ArgumentParser(description="Process a dataset with specific index range.")
parser.add_argument("--index", type=int, required=True, help="Starting index (inclusive).")
args = parser.parse_args()
index = args.index

# Shard size. With gap = len(data), only --index 0 selects any samples; for larger
# indices the range below is empty.
gap = len(data)
save_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Test_ALL_BBox_New_CapCoT_' + str(index) + '.json'
# gap = len(data)
# save_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Train_ALL_BBox_V0_Cap_' + str(index) + '.json'
# gap = 500
# begin = index * gap
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_' + str(index) + '.json'
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_standby' + str(index) + '.json'
# save_path = '/home/ma-user/work/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_WLCB' + str(index) + '.json'
# begin = (index+1)*gap - 2500
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_WLCB' + str(index) + '.json'
begin = index * gap
end = min((index + 1) * gap, len(data))  # never iterate past the end of the dataset
batch_size = 10
save_every = 100  # write a checkpoint after this many batches that ran generation
image_root = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/LLaVA-AiTW/'

# First task: a short caption of each screenshot.
# caption_prompt = " Describe the image in detail, including the main objects, their colors, positions, and relationships, as well as the background and any visible text. Highlight any actions, interactions, or notable details in a clear and concise manner. "
caption_prompt = " Provide a brief description of the image, including the main elements, their positions and relationships, as well as the background and any visible text, expressed clearly and concisely. "


def _load_image(relative_path):
    """Return an in-memory copy of the image at ``image_root + relative_path``.

    The file is opened in a ``with`` block, so its handle is closed immediately instead
    of staying open until the PIL object is garbage-collected.
    """
    with Image.open(image_root + relative_path) as img:
        return img.copy()


def _cot_prompt(sample):
    """Second task: build the reasoning prompt from the sample's goal and target element."""
    goal = sample['ori_question'].split('Goal:')[1]
    return (
        " The goal is : " + goal
        + " The target element is : " + sample['action_target']
        + " ###### Then analyze what's in the image and reason about that the target element of the image you should interact with in this step. "
    )


def _generate(images, prompts, max_new_tokens=512):
    """Run one batched generation and return the decoded completion for each prompt.

    ``images[k]`` is the single image attached to ``prompts[k]``. Only newly generated
    tokens are decoded (the echoed prompt is sliced off), and special/padding tokens
    are dropped.
    """
    input_texts = [
        processor.apply_chat_template(
            [{"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ]}],
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]
    inputs = processor(
        # One sub-list of images per text. A flat list can be interpreted as a single
        # sample holding every image, depending on the transformers version.
        images=[[image] for image in images],
        text=input_texts,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; all rows share the padded prompt length.
    prompt_len = inputs["input_ids"].shape[1]
    return [processor.decode(row[prompt_len:], skip_special_tokens=True) for row in output]


def _annotate(batch, key, build_prompt):
    """Store a model completion under ``key`` for every sample in ``batch`` that lacks it.

    Samples that already have ``key`` are skipped. Completions are matched back to
    exactly the samples that were sent to the model, not to positions in ``batch``.
    Samples are mutated in place; they are the same dict objects held in ``data``.

    Returns True if a generate call was made, False if nothing was pending.
    """
    pending = [sample for sample in batch if key not in sample]
    if not pending:
        return False
    images = [_load_image(sample['image']) for sample in pending]
    prompts = [build_prompt(sample) for sample in pending]
    for sample, completion in zip(pending, _generate(images, prompts)):
        sample[key] = completion
    return True


# NOTE(review): the "already annotated" skips only help if `temp` itself already
# contains these keys. The script reads `temp`, not `save_path`, so rerunning does not
# resume from a previous checkpoint.
processed_batches = 0
for batch_start in tqdm(range(begin, end, batch_size)):
    # Clamp to `end` so the last batch never spills past this shard's range.
    batch = data[batch_start:min(batch_start + batch_size, end)]
    # The two tasks run independently: a batch whose captions are all done still gets CoT.
    ran_caption = _annotate(batch, '90B_caption', lambda sample: caption_prompt)
    ran_cot = _annotate(batch, '90B_CoT', _cot_prompt)
    if not (ran_caption or ran_cot):
        continue
    processed_batches += 1
    # Periodically checkpoint the whole dataset (every `save_every` generating batches).
    if processed_batches % save_every == 0:
        print(f"Saving data at iteration {processed_batches}")
        write_json(save_path, data)

# Final save for batches processed since the last periodic checkpoint. Nothing is
# written when no batch needed generation.
if processed_batches % save_every != 0:
    write_json(save_path, data)

# Unused alternative: single-image prompt asking for interactable-element positions.
# messages = [
# {"role": "user", "content": [
# {"type": "image"},
# {"type": "text", "text": "Detailed description of the content in the image and the location of the elements that can be interacted with. The position information can be the scale of the center point of the interactable element in the image with the upper left corner as the origin (0, 0). The scale of the image is (width, height). The unit of the position information is the percentage of the width and height of the image. For example, if the image is 800*400, the position of the upper left corner is (0, 0), and the position of the lower right corner is (100, 100). The position of the center of the image is (50, 50). Such as, the location of Search bar is at (20,60) . "}
# ]}
# ]
# input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
# inputs = processor(
# image,
# input_text,
# add_special_tokens=False,
# return_tensors="pt",
# ).to(model.device)
# output = model.generate(**inputs, max_new_tokens=512)
# print(processor.decode(output[0]))