Wendy committed
Commit c75bc06 · verified · 1 parent: 5e8edf0

Upload cogagent_infer.py with huggingface_hub

Files changed (1)
  1. cogagent_infer.py +204 -0
cogagent_infer.py ADDED
@@ -0,0 +1,204 @@
+ import argparse
+ import os
+ import re
+ import torch
+ from PIL import Image, ImageDraw
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from typing import List
+ import json
+
+
+ def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
+     """
+     Draws red bounding boxes on the given image and saves it.
+
+     Parameters:
+     - image (PIL.Image.Image): The image on which to draw the bounding boxes.
+     - boxes (List[List[float]]): A list of bounding boxes, each defined as
+       [x_min, y_min, x_max, y_max]. Coordinates are expected to be normalized (0 to 1).
+     - save_path (str): The path to save the annotated image.
+
+     Description:
+     Each box coordinate is a fraction of the image dimension. This function converts
+     them to actual pixel coordinates, draws a red rectangle to mark the area, and
+     saves the annotated image to the specified path.
+     """
+     draw = ImageDraw.Draw(image)
+     for box in boxes:
+         x_min = int(box[0] * image.width)
+         y_min = int(box[1] * image.height)
+         x_max = int(box[2] * image.width)
+         y_max = int(box[3] * image.height)
+         draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
+     image.save(save_path)
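+
+ # Illustrative example (hypothetical numbers): on a 1000x800 image, a normalized
+ # box [0.12, 0.25, 0.48, 0.60] is drawn at pixel coordinates (120, 200, 480, 480).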
+
+
+ def main():
+     """
+     Batch inference demo using the CogAgent1.5 model with selectable format prompts.
+     Each record in the --input_json file is processed as one round. The
+     output_image_path is interpreted as a directory; for each round, the annotated
+     image is saved there with the filename:
+     {original_image_name_without_extension}_{round_number}.png
+
+     Example:
+     python cogagent_infer.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
+         --output_image_path ./results --format_key status_action_op_sensitive
+     """
+     parser = argparse.ArgumentParser(
+         description="Batch inference demo with the CogAgent model and a selectable prompt format."
+     )
+     parser.add_argument(
+         "--model_dir", required=True, help="Path or identifier of the model."
+     )
+     parser.add_argument(
+         "--platform",
+         default="Mac",
+         help="Platform information string (e.g., 'Mac', 'WIN').",
+     )
+     parser.add_argument(
+         "--max_length", type=int, default=4096, help="Maximum generation length."
+     )
+     parser.add_argument(
+         "--top_k", type=int, default=1, help="Top-k sampling parameter."
+     )
+     parser.add_argument(
+         "--output_image_path",
+         default="results",
+         help="Directory to save the annotated images.",
+     )
+     parser.add_argument(
+         "--input_json",
+         default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
+         help="Path to the input JSON file listing images and tasks.",
+     )
+     parser.add_argument(
+         "--output_json",
+         default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
+         help="Path to write the collected query/response records "
+         "(choose a file different from --input_json to avoid overwriting it).",
+     )
+     parser.add_argument(
+         "--format_key",
+         default="action_op_sensitive",
+         help="Key to select the prompt format.",
+     )
+     args = parser.parse_args()
+
+     # Dictionary mapping format keys to format strings
+     format_dict = {
+         "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
+         "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
+         "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
+         "status_action_op": "(Answer in Status-Action-Operation format.)",
+         "action_op": "(Answer in Action-Operation format.)",
+     }
+
+     # Ensure the provided format_key is valid
+     if args.format_key not in format_dict:
+         raise ValueError(
+             f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
+         )
+
+     # Ensure the output directory exists
+     os.makedirs(args.output_image_path, exist_ok=True)
+
+     # Load the tokenizer and model
+     tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_dir,
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True,
+         device_map="auto",
+         # quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # For INT8 quantization
+         # quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # For INT4 quantization
+     ).eval()
+
+     # Initialize platform and selected format strings
+     platform_str = f"(Platform: {args.platform})\n"
+     format_str = format_dict[args.format_key]
+
+     # Histories of grounded operations and actions across rounds
+     history_step = []
+     history_action = []
+
+     round_num = 1
+     with open(args.input_json, "r") as f:
+         data = json.load(f)
+
+     res = []
+     for x in data:
+         img_path = x["image"]
+         image = Image.open(img_path).convert("RGB")
+         task = x["conversations"][0]["value"]
+
+         # Verify history lengths match
+         if len(history_step) != len(history_action):
+             raise ValueError("Mismatch in lengths of history_step and history_action.")
+
+         # Format history steps for the prompt
+         history_str = "\nHistory steps: "
+         for index, (step, action) in enumerate(zip(history_step, history_action)):
+             history_str += f"\n{index}. {step}\t{action}"
+
+         # Compose the query with task, history, platform, and selected format instructions
+         query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
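+         # For example, with a hypothetical task and an empty history, the query is:
+         # Task: Open the settings menu
+         # History steps:
+         # (Platform: Mac)
+         # (Answer in Action-Operation-Sensitive format.)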
+
+         # print(f"Round {round_num} query:\n{query}")
+
+         inputs = tokenizer.apply_chat_template(
+             [{"role": "user", "image": image, "content": query}],
+             add_generation_prompt=True,
+             tokenize=True,
+             return_tensors="pt",
+             return_dict=True,
+         ).to(model.device)
+
+         # Generation parameters
+         gen_kwargs = {
+             "max_length": args.max_length,
+             "do_sample": True,
+             "top_k": args.top_k,
+         }
+
+         # Generate a response and drop the prompt tokens from the output
+         with torch.no_grad():
+             outputs = model.generate(**inputs, **gen_kwargs)
+             outputs = outputs[:, inputs["input_ids"].shape[1]:]
+             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # print(f"Model response:\n{response}")
+
+         # Extract the grounded operation and the action to extend the history
+         grounded_pattern = r"Grounded Operation:\s*(.*)"
+         action_pattern = r"Action:\s*(.*)"
+         matches_history = re.search(grounded_pattern, response)
+         matches_actions = re.search(action_pattern, response)
+
+         if matches_history:
+             history_step.append(matches_history.group(1))
+         if matches_actions:
+             history_action.append(matches_actions.group(1))
+
+         # Extract bounding boxes from the response; the model emits integer
+         # coordinates on a 0-999 scale, so divide by 1000 to normalize them.
+         box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
+         matches = re.findall(box_pattern, response)
+         output_path = None  # Remains None when the response contains no box
+         if matches:
+             boxes = [[int(coord) / 1000 for coord in match] for match in matches]
+
+             # Name the output file after the input image's base name and the round number
+             base_name = os.path.splitext(os.path.basename(img_path))[0]
+             output_file_name = f"{base_name}_{round_num}.png"
+             output_path = os.path.join(args.output_image_path, output_file_name)
+
+             draw_boxes_on_image(image, boxes, output_path)
+             # print(f"Annotated image saved at: {output_path}")
+
+         ans = {
+             "query": query,
+             "response": response,
+             "output_path": output_path,
+         }
+         res.append(ans)
+         round_num += 1
+
+     # Write the collected query/response records to --output_json
+     with open(args.output_json, "w", encoding="utf-8") as file:
+         json.dump(res, file, ensure_ascii=False, indent=4)
+
+
+ if __name__ == "__main__":
+     main()
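
Note: the script only reads two fields from each --input_json record: an "image" path that PIL can open, and the task text at conversations[0]["value"]. A minimal sketch of building a matching input file follows; the file names and task text are hypothetical, chosen only to illustrate the schema inferred from those field accesses.

import json

# One record per image/task pair; any extra fields are ignored by cogagent_infer.py.
records = [
    {
        "image": "screenshots/home.png",  # hypothetical path readable by PIL
        "conversations": [{"value": "Open the settings menu"}],
    }
]

with open("sample_input.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=4)

# Then run, for example:
# python cogagent_infer.py --model_dir THUDM/cogagent-9b-20241220 \
#     --input_json sample_input.json --output_json results.json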