Wendy committed on
Commit 9144502 · verified · 1 parent: 4b515a5

Upload cogagent_infer.py with huggingface_hub

Files changed (1)
  1. cogagent_infer.py +213 -0
cogagent_infer.py ADDED
@@ -0,0 +1,213 @@
+ import argparse
+ import json
+ import os
+ import re
+ from typing import List
+
+ import torch
+ from PIL import Image, ImageDraw
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+
+ def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
+     """
+     Draw red bounding boxes on the given image and save it.
+
+     Parameters:
+     - image (PIL.Image.Image): The image on which to draw the bounding boxes.
+     - boxes (List[List[float]]): A list of bounding boxes, each defined as
+       [x_min, y_min, x_max, y_max], with coordinates normalized to the range 0-1.
+     - save_path (str): The path at which to save the annotated image.
+
+     Description:
+     Each box coordinate is a fraction of the image dimension. This function converts
+     the fractions to pixel coordinates, draws a red rectangle over each region, and
+     saves the annotated image to the specified path.
+     """
+     draw = ImageDraw.Draw(image)
+     for box in boxes:
+         x_min = int(box[0] * image.width)
+         y_min = int(box[1] * image.height)
+         x_max = int(box[2] * image.width)
+         y_max = int(box[3] * image.height)
+         draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
+     image.save(save_path)
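+ # Illustrative usage (not part of the original script):
+ #     draw_boxes_on_image(img, [[0.1, 0.2, 0.5, 0.6]], "out.png")
+ # marks a region spanning 10-50% of the width and 20-60% of the height.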
+
+
+ def main():
+     """
+     Batch inference over a JSON dataset with the CogAgent model and selectable format prompts.
+     The output_image_path is interpreted as a directory. For each round (one dataset entry),
+     the annotated image is saved in that directory with the filename:
+     {original_image_name_without_extension}_{round_number}.png
+
+     Example:
+     python cogagent_infer.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
+         --output_image_path ./results --format_key status_action_op_sensitive
+     """
+
+     parser = argparse.ArgumentParser(
+         description="Batch inference with the CogAgent model and a selectable prompt format."
+     )
+     parser.add_argument(
+         "--model_dir", required=True, help="Path or identifier of the model."
+     )
+     parser.add_argument(
+         "--platform",
+         default="Mac",
+         help="Platform information string (e.g., 'Mac', 'WIN').",
+     )
+     parser.add_argument(
+         "--max_length", type=int, default=4096, help="Maximum generation length."
+     )
+     parser.add_argument(
+         "--top_k", type=int, default=1, help="Top-k sampling parameter."
+     )
+     parser.add_argument(
+         "--output_image_path",
+         default="results",
+         help="Directory to save the annotated images.",
+     )
+     parser.add_argument(
+         "--input_json",
+         default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
+         help="Path to the input JSON file of test cases (each with an 'image' path and 'conversations').",
+     )
+     parser.add_argument(
+         "--output_json",
+         # NOTE: defaults to the same path as --input_json; pass a different path
+         # to avoid overwriting the input file.
+         default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
+         help="Path at which to save the inference results as JSON.",
+     )
+     parser.add_argument(
+         "--format_key",
+         default="action_op_sensitive",
+         help="Key to select the prompt format.",
+     )
+     args = parser.parse_args()
+
+     # Dictionary mapping format keys to format strings
+     format_dict = {
+         "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
+         "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
+         "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
+         "status_action_op": "(Answer in Status-Action-Operation format.)",
+         "action_op": "(Answer in Action-Operation format.)",
+     }
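+     # These suffixes mirror the output formats documented for the
+     # cogagent-9b-20241220 demo; they control which sections (e.g. Status,
+     # Plan, Action, Grounded Operation) the model includes in its reply.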
+
+     # Ensure the provided format_key is valid
+     if args.format_key not in format_dict:
+         raise ValueError(
+             f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
+         )
+
+     # Ensure the output directory exists
+     os.makedirs(args.output_image_path, exist_ok=True)
+
+     # Load the tokenizer and model
+     tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_dir,
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True,
+         device_map="auto",
+         # quantization_config=BitsAndBytesConfig(load_in_8bit=True), # For INT8 quantization
+         # quantization_config=BitsAndBytesConfig(load_in_4bit=True), # For INT4 quantization
+     ).eval()
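+     # trust_remote_code=True is required because CogAgent ships custom modeling
+     # code; device_map="auto" lets accelerate place weights on available devices.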
+     # Initialize platform and selected format strings
+     platform_str = f"(Platform: {args.platform})\n"
+     format_str = format_dict[args.format_key]
+
+     # Initialize history lists
+     history_step = []
+     history_action = []
+
+     round_num = 1
+     with open(args.input_json, "r") as f:
+         data = json.load(f)
+
+     res = []
+     for i in tqdm(range(len(data))):
+         x = data[i]
+         img_path = x['image']
+         image = Image.open(img_path).convert("RGB")
+         task = x['conversations'][0]['value']
+
+         # Verify that the history lists are still in sync; skip the case if not.
+         if len(history_step) != len(history_action):
+             print("Warning: mismatch in lengths of history_step and history_action - skipping current case")
+             continue
+
+         # Format history steps for output
+         history_str = "\nHistory steps: "
+         for index, (step, action) in enumerate(zip(history_step, history_action)):
+             history_str += f"\n{index}. {step}\t{action}"
+
+         # Compose the query with task, platform, and selected format instructions
+         query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
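+         # Illustrative final prompt: "Task: <instruction>\nHistory steps: \n
+         # 0. <grounded op>\t<action>\n(Platform: Mac)\n(Answer in Action-Operation-Sensitive format.)"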
+
+         # print(f"Round {round_num} query:\n{query}")
+
+         inputs = tokenizer.apply_chat_template(
+             [{"role": "user", "image": image, "content": query}],
+             add_generation_prompt=True,
+             tokenize=True,
+             return_tensors="pt",
+             return_dict=True,
+         ).to(model.device)
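+         # The "image" entry is presumably consumed by the model's remote-code chat
+         # template, which packs the pixel data alongside the tokenized text.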
+         # Generation parameters
+         gen_kwargs = {
+             "max_length": args.max_length,
+             "do_sample": True,
+             "top_k": args.top_k,
+         }
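+         # Note: with the default top_k=1, sampling reduces to greedy decoding,
+         # making runs effectively deterministic.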
+
+         # Generate response
+         with torch.no_grad():
+             outputs = model.generate(**inputs, **gen_kwargs)
+             outputs = outputs[:, inputs["input_ids"].shape[1]:]
+             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # print(f"Model response:\n{response}")
+
+         # Extract grounded operation and action
+         grounded_pattern = r"Grounded Operation:\s*(.*)"
+         action_pattern = r"Action:\s*(.*)"
+         matches_history = re.search(grounded_pattern, response)
+         matches_actions = re.search(action_pattern, response)
+
+         # Append to the histories only as a pair, so the two lists stay in sync.
+         if matches_history and matches_actions:
+             history_step.append(matches_history.group(1))
+             history_action.append(matches_actions.group(1))
+
+         # Extract bounding boxes from the response
+         box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
+         matches = re.findall(box_pattern, response)
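+         # CogAgent grounds boxes on a 0-1000 scale relative to the image size; the
+         # division by 1000 below yields the 0-1 fractions draw_boxes_on_image expects.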
+         output_path = None  # remains None when the response contains no boxes
+         if matches:
+             boxes = [[int(coord) / 1000 for coord in match] for match in matches]
+
+             # Extract base name of the user's input image (without extension)
+             base_name = os.path.splitext(os.path.basename(img_path))[0]
+             # Construct the output file name with round number
+             output_file_name = f"{base_name}_{round_num}.png"
+             output_path = os.path.join(args.output_image_path, output_file_name)
+
+             draw_boxes_on_image(image, boxes, output_path)
+             # print(f"Annotated image saved at: {output_path}")
+
+         ans = {
+             'query': f"Round {round_num} query:\n{query}",
+             'response': response,
+             'output_path': output_path,
+         }
+         res.append(ans)
+         round_num += 1
+
+     # print(res)
+     with open(args.output_json, "w", encoding="utf-8") as file:
+         json.dump(res, file, ensure_ascii=False, indent=4)
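+     # Each output entry pairs the composed query with the raw response and the
+     # annotated-image path (None/null when the response contained no boxes).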
+
+
+ if __name__ == "__main__":
+     main()