Wendy committed
Commit d2468c2 · verified · 1 Parent(s): f33e794

Upload cogagent_infer_batch.py with huggingface_hub
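For context, an upload like this is typically produced with the huggingface_hub client API. A minimal sketch is below; the repo_id is a placeholder, not taken from this commit.

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="cogagent_infer_batch.py",
        path_in_repo="cogagent_infer_batch.py",
        repo_id="user/repo",  # placeholder; the actual repo id is not shown in this commit
        commit_message="Upload cogagent_infer_batch.py with huggingface_hub",
    )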

Files changed (1)
  1. cogagent_infer_batch.py +230 -0
cogagent_infer_batch.py ADDED
@@ -0,0 +1,230 @@
import argparse
import os
import re
import torch
from PIL import Image, ImageDraw
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import List
import json
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader


class AITM_Dataset(Dataset):
    """Dataset over a JSONL file; each line holds an image path and a task prompt."""

    def __init__(self, json_path):
        self.data = []
        with open(json_path, 'r') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        img_path = x['image']
        task = x['conversations'][0]['value']
        return img_path, task


def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
    """
    Draws red bounding boxes on the given image and saves it.

    Parameters:
    - image (PIL.Image.Image): The image on which to draw the bounding boxes.
    - boxes (List[List[float]]): A list of bounding boxes, each defined as [x_min, y_min, x_max, y_max].
      Coordinates are expected to be normalized (0 to 1).
    - save_path (str): The path to save the updated image.

    Description:
    Each box coordinate is a fraction of the image dimension. This function converts them to actual pixel
    coordinates and draws a red rectangle to mark the area. The annotated image is then saved to the specified path.
    """
    draw = ImageDraw.Draw(image)
    for box in boxes:
        # Convert normalized coordinates to pixel coordinates.
        x_min = int(box[0] * image.width)
        y_min = int(box[1] * image.height)
        x_max = int(box[2] * image.width)
        y_max = int(box[3] * image.height)
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
    image.save(save_path)


def main():
    """
    Batch inference over a JSONL dataset using the CogAgent1.5 model with selectable format prompts.
    The output_image_path is interpreted as a directory. For each processed sample,
    the annotated image will be saved in the directory with the filename:
    {original_image_name_without_extension}_{round_number}.png

    Example:
    python cogagent_infer_batch.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
        --output_image_path ./results --format_key status_action_op_sensitive
    """
    parser = argparse.ArgumentParser(
        description="Batch inference with the CogAgent model and selectable format."
    )
    parser.add_argument(
        "--model_dir", required=True, help="Path or identifier of the model."
    )
    parser.add_argument(
        "--platform",
        default="Mac",
        help="Platform information string (e.g., 'Mac', 'WIN').",
    )
    parser.add_argument(
        "--max_length", type=int, default=4096, help="Maximum generation length."
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter."
    )
    parser.add_argument(
        "--output_image_path",
        default="results",
        help="Directory to save the annotated images.",
    )
    parser.add_argument(
        "--input_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
        help="Path to the input JSONL file (one sample per line).",
    )
    # Note: the default below matches --input_json; pass a different
    # --output_json to avoid overwriting the input file.
    parser.add_argument(
        "--output_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
        help="Path to write the JSON file of model responses.",
    )
    parser.add_argument(
        "--format_key",
        default="action_op_sensitive",
        help="Key to select the prompt format.",
    )
    args = parser.parse_args()

    # Dictionary mapping format keys to format strings
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }

    # Ensure the provided format_key is valid
    if args.format_key not in format_dict:
        raise ValueError(
            f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
        )

    # Ensure the output directory exists
    os.makedirs(args.output_image_path, exist_ok=True)

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # For INT8 quantization
        # quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # For INT4 quantization
    ).eval()

    # Initialize platform and selected format strings
    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]

    # Initialize history lists
    history_step = []
    history_action = []

    round_num = 1
    dataset = AITM_Dataset(args.input_json)
    data_loader = DataLoader(dataset, batch_size=16, shuffle=False)

    res = []
    for img_paths, tasks in tqdm(data_loader, desc="Processing items"):
        # The DataLoader yields a batch of image paths and tasks; unpack the
        # batch and run inference on each sample individually.
        for img_path, task in zip(img_paths, tasks):
            image = Image.open(img_path).convert("RGB")

            # Verify history lengths match
            if len(history_step) != len(history_action):
                raise ValueError("Mismatch in lengths of history_step and history_action.")

            # Format history steps for output
            history_str = "\nHistory steps: "
            for index, (step, action) in enumerate(zip(history_step, history_action)):
                history_str += f"\n{index}. {step}\t{action}"

            # Compose the query with task, platform, and selected format instructions
            query = f"Task: {task}{history_str}\n{platform_str}{format_str}"

            inputs = tokenizer.apply_chat_template(
                [{"role": "user", "image": image, "content": query}],
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
            ).to(model.device)

            # Generation parameters
            gen_kwargs = {
                "max_length": args.max_length,
                "do_sample": True,
                "top_k": args.top_k,
            }

            # Generate response and strip the prompt tokens from the output
            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)
                outputs = outputs[:, inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract grounded operation and action
            grounded_pattern = r"Grounded Operation:\s*(.*)"
            action_pattern = r"Action:\s*(.*)"
            matches_history = re.search(grounded_pattern, response)
            matches_actions = re.search(action_pattern, response)

            if matches_history:
                grounded_operation = matches_history.group(1)
                history_step.append(grounded_operation)
            if matches_actions:
                action_operation = matches_actions.group(1)
                history_action.append(action_operation)

            # Extract bounding boxes from the response
            box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
            matches = re.findall(box_pattern, response)
            output_path = None
            if matches:
                # Box coordinates are emitted on a 0-1000 scale; normalize to 0-1.
                boxes = [[int(x) / 1000 for x in match] for match in matches]

                # Extract base name of the input image (without extension)
                base_name = os.path.splitext(os.path.basename(img_path))[0]
                # Construct the output file name with round number
                output_file_name = f"{base_name}_{round_num}.png"
                output_path = os.path.join(args.output_image_path, output_file_name)

                draw_boxes_on_image(image, boxes, output_path)

            ans = {
                'query': f"Round {round_num} query:\n{query}",
                'response': response,
                'output_path': output_path
            }
            res.append(ans)
            round_num += 1
    print(res)
    with open(args.output_json, "w") as file:
        print("Writing to json file")
        json.dump(res, file, indent=4)


if __name__ == "__main__":
    main()
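
For reference, AITM_Dataset above reads one JSON object per line, taking the image path from the 'image' field and the task from conversations[0]['value']. A minimal sketch of a compatible input line follows; the file name, path, "from" key, and task text are illustrative placeholders, not values from this repo.

    import json

    # Hypothetical sample line for the input JSONL; only "image" and
    # conversations[0]["value"] are actually read by the dataset.
    sample = {
        "image": "/data/screens/example_0001.png",
        "conversations": [{"from": "human", "value": "Open the settings menu"}],
    }
    with open("AITM_sample.jsonl", "w") as f:
        f.write(json.dumps(sample) + "\n")

A file of such lines can then be passed via --input_json when running the script.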