Wendy committed on
Commit
88ee19f
·
verified ·
1 Parent(s): 09d9fbd

Upload LLaMA_90B_infer_batch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. LLaMA_90B_infer_batch.py +243 -0
LLaMA_90B_infer_batch.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import requests
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import MllamaForConditionalGeneration, AutoProcessor
6
+
7
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_id = "Model/Llama-3.2-90B-Vision-Instruct"

# Load the vision-language model in bfloat16 and let `device_map="auto"`
# decide where the weights are placed.
_model_load_options = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
model = MllamaForConditionalGeneration.from_pretrained(model_id, **_model_load_options)

# Matching processor: builds chat-template prompts and model inputs, and
# decodes generated token ids back to text.
processor = AutoProcessor.from_pretrained(model_id)

# Remote sample image from an earlier experiment (kept for reference):
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
# image = Image.open(requests.get(url, stream=True).raw)

# Local sample image.
# NOTE(review): both `temp` and `image` are re-assigned further down before
# being read, so this open looks like a leftover that only adds a hard
# dependency on the file existing - confirm before removing.
temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/10_1.png'
image = Image.open(temp)
20
+
21
+
22
+
23
+
24
+ import json
25
+ import pprint
26
+ from tqdm import tqdm
27
+ import json
28
+ import argparse
29
+
30
+
31
+
32
def read_json(file_path):
    """Parse the UTF-8 encoded JSON file at ``file_path`` and return its content.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON document (whatever ``json.load`` produces).
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
36
+
37
def write_json(file_path, data):
    """Serialize ``data`` as indented UTF-8 JSON and store it at ``file_path``.

    The document is serialized in memory first, written to a sibling
    ``<file_path>.tmp`` file, and then moved over the destination with
    ``os.replace``. Opening the destination directly in ``'w'`` mode would
    truncate it before serialization starts, so a non-serializable value or
    an interrupted run would destroy the previous checkpoint.

    Args:
        file_path: Destination path of the JSON file.
        data: JSON-serializable object to store.

    Raises:
        TypeError: If ``data`` contains a value ``json`` cannot serialize;
            an existing file at ``file_path`` is left untouched in that case.
    """
    import os  # local import: keeps this function self-contained

    # Serialize before touching the filesystem so failures cannot clobber
    # an existing file.
    payload = json.dumps(data, ensure_ascii=False, indent=4)

    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as file:
        file.write(payload)

    # Swap the finished file into place in a single step.
    os.replace(tmp_path, file_path)
40
+
41
+ # data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/MiniCPM-V/all_blip_train_llava_coco_layout_caption_s1s3.json")
42
+ # data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_train_llava_coco_layout_all_test.json")
43
+ # data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_train_llava_coco_layout_all_train.json")
44
+ # data = read_json("/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Train_ALL_BBox_V0_Half.json")
45
+
46
+ # temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_0.json'
47
+ # temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_1.json'
48
+ # temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_2.json'
49
+ # temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_3.json'
50
+ # temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_4.json'
51
+ # temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/AITM_Test_ALL_V0_down.json'
52
+
53
+
54
+ # temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/Json/all_blip_test_llava_coco_layout_all_bbox_v3.json'
55
# Annotation file whose records get enriched by this run.
temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/all_blip_test_llava_coco_layout_AITM_0.json'
# temp = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM_Json/AITM_Train_ALL_BBox_V0.json'

data = read_json(temp)

# Command line: a single required integer, --index.
parser = argparse.ArgumentParser(description="Process a dataset with specific index range.")
parser.add_argument("--index", type=int, required=True, help="Starting index (inclusive).")
args = parser.parse_args()

index = args.index

# Size of the slice handled by this process.
# NOTE(review): with gap == len(data), any positive --index gives a range
# starting at or beyond the end of `data`, so --index seems to serve only as
# the output-file suffix here - confirm that this is intended.
gap = len(data)
save_path = f'/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Test_ALL_BBox_New_CapCoT_{index}.json'

# Earlier configurations, kept for reference:
# gap = len(data)
# save_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/AITM/AITM_Train_ALL_BBox_V0_Cap_' + str(index) + '.json'

# gap = 500
# begin = index * gap
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_' + str(index) + '.json'
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_standby' + str(index) + '.json'
# save_path = '/home/ma-user/work/albus/DataSet/all_blip_train_llava_coco_layout_all_train_AITM_WLCB' + str(index) + '.json'

# begin = (index+1)*gap - 2500
# save_path = 'DataSet/all_blip_train_llava_coco_layout_all_train_AITM_WLCB' + str(index) + '.json'

# Half-open record range [begin, end) processed by this run.
begin = index * gap
end = begin + gap

# `counter` tracks loop iterations for periodic saving; `batch_size` is the
# number of records sent to the model per generate() call.
counter = 0
batch_size = 10
90
+ # for idx, i in enumerate(tqdm(data[begin:end])):
91
+
92
# Directory that each record's relative `image` path is joined onto.
_IMAGE_ROOT = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/DataSet/LLaVA-AiTW/'

# First task: short image caption.
# Previous, more detailed variant of the caption prompt:
# " Describe the image in detail, including the main objects, their colors, positions, and relationships, as well as the background and any visible text. Highlight any actions, interactions, or notable details in a clear and concise manner. "
_CAPTION_PROMPT = " Provide a brief description of the image, including the main elements, their positions and relationships, as well as the background and any visible text, expressed clearly and concisely. "

# Number of processed batches between two intermediate saves.
_SAVE_EVERY = 100


def _caption_prompt(record):
    """Return the caption prompt; it does not depend on the record."""
    return _CAPTION_PROMPT


def _cot_prompt(record):
    """Build the second-task (reasoning) prompt from the record's goal and target element."""
    # The goal is whatever follows the first 'Goal:' marker in `ori_question`.
    goal = record['ori_question'].split('Goal:')[1]
    action_target = record['action_target']
    return " The goal is : " + goal + " The target element is : " + action_target + " ###### Then analyze what's in the image and reason about that the target element of the image you should interact with in this step. "


def _generate_missing(records, out_key, make_prompt):
    """Generate text for every record lacking ``out_key`` and store it there.

    Args:
        records: Records (dicts) of the current batch; updated in place.
        out_key: Key under which the decoded generation is stored. Records
            that already contain this key are left untouched.
        make_prompt: Callable mapping a record to its user prompt text.

    Returns:
        True if at least one record was sent to the model, False otherwise.
    """
    # Only these records are sent to the model, so only these may receive an
    # output row. Indexing the output by position in the full batch would
    # mis-assign results whenever some record was skipped.
    pending = [record for record in records if out_key not in record]
    if not pending:
        return False

    prompts = []
    for record in pending:
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": make_prompt(record)},
            ]}
        ]
        prompts.append(processor.apply_chat_template(messages, add_generation_prompt=True))

    images = []
    try:
        for record in pending:
            images.append(Image.open(_IMAGE_ROOT + record['image']))
        # NOTE(review): confirm that the installed processor pairs a flat
        # list of images one-to-one with the prompts (some versions may want
        # [[img], [img], ...]) and that its padding side suits batched
        # generation.
        inputs = processor(
            images,
            prompts,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
        ).to(model.device)
    finally:
        # The processor has produced its tensors; release the file handles.
        for opened in images:
            opened.close()

    output = model.generate(**inputs, max_new_tokens=512)

    # NOTE(review): the whole returned sequence is decoded, which presumably
    # includes the prompt and special tokens - confirm consumers expect that.
    for record, sequence in zip(pending, output):
        record[out_key] = processor.decode(sequence)
    return True


# An earlier (disabled) revision also re-generated captions that contained
# phrases such as 'no image', 'no diagram', "don't see " or "didn't provide".
for batch_idx in tqdm(range(begin, end, batch_size)):
    # Never read past `end`, even when the range is not a multiple of batch_size.
    batch = data[batch_idx:min(batch_idx + batch_size, end)]

    # Run both tasks independently, so records that already have a caption
    # but no reasoning output still get the second task.
    captioned = _generate_missing(batch, '90B_caption', _caption_prompt)
    reasoned = _generate_missing(batch, '90B_CoT', _cot_prompt)
    if not (captioned or reasoned):
        continue

    # Save periodically (every _SAVE_EVERY processed batches).
    counter += 1
    if counter % _SAVE_EVERY == 0:
        print(f"Saving data at iteration {counter}")
        write_json(save_path, data)

# Final save so results produced since the last periodic save are not lost.
write_json(save_path, data)
213
+
214
+
215
+
216
+
217
+
218
+
219
+
220
+
221
+
222
+
223
+
224
+ # messages = [
225
+ # {"role": "user", "content": [
226
+ # {"type": "image"},
227
+ # {"type": "text", "text": "Detailed description of the content in the image and the location of the elements that can be interacted with. The position information can be the scale of the center point of the interactable element in the image with the upper left corner as the origin (0, 0). The scale of the image is (width, height). The unit of the position information is the percentage of the width and height of the image. For example, if the image is 800*400, the position of the upper left corner is (0, 0), and the position of the lower right corner is (100, 100). The position of the center of the image is (50, 50). Such as, the location of Search bar is at (20,60) . "}
228
+ # ]}
229
+ # ]
230
+
231
+
232
+
233
+ # input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
234
+ # inputs = processor(
235
+ # image,
236
+ # input_text,
237
+ # add_special_tokens=False,
238
+ # return_tensors="pt",
239
+ # ).to(model.device)
240
+
241
+ # output = model.generate(**inputs, max_new_tokens=512)
242
+ # print(processor.decode(output[0]))
243
+