import argparse
import json
import os
import re
from typing import List

import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
    """
    Draws red bounding boxes on the given image and saves it.

    Parameters:
    - image (PIL.Image.Image): The image on which to draw the bounding boxes.
    - boxes (List[List[float]]): A list of bounding boxes, each defined as
      [x_min, y_min, x_max, y_max]. Coordinates are expected to be normalized (0 to 1).
    - save_path (str): The path to save the annotated image.

    Description:
    Each box coordinate is a fraction of the image dimension. This function converts
    them to pixel coordinates, draws a red rectangle over each region, and saves the
    annotated image to the specified path.
    """
    draw = ImageDraw.Draw(image)
    for box in boxes:
        x_min = int(box[0] * image.width)
        y_min = int(box[1] * image.height)
        x_max = int(box[2] * image.width)
        y_max = int(box[3] * image.height)
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
    image.save(save_path)
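
# A minimal usage sketch for draw_boxes_on_image (hypothetical paths; note that
# the function draws onto the PIL image in place before saving):
#   img = Image.open("screenshot.png").convert("RGB")
#   draw_boxes_on_image(img, [[0.1, 0.2, 0.4, 0.5]], "screenshot_annotated.png")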
def main():
    """
    A batch demo using the CogAgent1.5 model with selectable format prompts.
    Cases are read from the --input_json file. The output_image_path is
    interpreted as a directory; for each round, the annotated image is saved
    there with the filename:
        {original_image_name_without_extension}_{round_number}.png

    Example:
        python cli_demo.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
            --output_image_path ./results --format_key status_action_op_sensitive
    """
    parser = argparse.ArgumentParser(
        description="Batch demo with the CogAgent model and selectable format."
    )
    parser.add_argument(
        "--model_dir", required=True, help="Path or identifier of the model."
    )
    parser.add_argument(
        "--platform",
        default="Mac",
        help="Platform information string (e.g., 'Mac', 'WIN').",
    )
    parser.add_argument(
        "--max_length", type=int, default=4096, help="Maximum generation length."
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter."
    )
    parser.add_argument(
        "--output_image_path",
        default="results",
        help="Directory to save the annotated images.",
    )
    parser.add_argument(
        "--input_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
        help="Path to the input JSON file with test cases.",
    )
    parser.add_argument(
        "--output_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0_output.json",
        help="Path to write the JSON results (kept distinct from --input_json so the input is not overwritten).",
    )
    parser.add_argument(
        "--format_key",
        default="action_op_sensitive",
        help="Key to select the prompt format.",
    )
    args = parser.parse_args()
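
    # Suffixes appended to each query to select the answer schema CogAgent
    # should produce.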
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }

    if args.format_key not in format_dict:
        raise ValueError(
            f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
        )

    os.makedirs(args.output_image_path, exist_ok=True)
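
    # Load tokenizer and model; device_map="auto" places the weights on the
    # available device(s), and bfloat16 keeps memory usage moderate.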
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    ).eval()
    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]
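
    # Grounded operations and actions accumulated across rounds; they are
    # replayed into each new query as "History steps".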
    history_step = []
    history_action = []

    round_num = 1
    with open(args.input_json, "r") as f:
        data = json.load(f)

    res = []
    for item in tqdm(data):
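        # Each case carries the screenshot path under "image" and the task
        # prompt as the first conversation turn.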
        img_path = item["image"]
        image = Image.open(img_path).convert("RGB")
        task = item["conversations"][0]["value"]
        # The two history lists must stay in lockstep; skip malformed state.
        if len(history_step) != len(history_action):
            print("Warning: mismatch in lengths of history_step and history_action - skipping this case.")
            continue
        history_str = "\nHistory steps: "
        for index, (step, action) in enumerate(zip(history_step, history_action)):
            history_str += f"\n{index}. {step}\t{action}"

        query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
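
        # Illustrative shape of the assembled query (values are examples):
        #   Task: <task text>
        #   History steps:
        #   0. <grounded operation>\t<action>
        #   (Platform: Mac)
        #   (Answer in Action-Operation-Sensitive format.)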
        inputs = tokenizer.apply_chat_template(
            [{"role": "user", "image": image, "content": query}],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)
        gen_kwargs = {
            "max_length": args.max_length,
            "do_sample": True,
            "top_k": args.top_k,
        }

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            # Drop the prompt tokens so only the newly generated text is decoded.
            outputs = outputs[:, inputs["input_ids"].shape[1]:]
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
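
        # Extract the "Grounded Operation:" and "Action:" lines from the
        # response so they can be replayed as history in subsequent rounds.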
        grounded_pattern = r"Grounded Operation:\s*(.*)"
        action_pattern = r"Action:\s*(.*)"
        matches_history = re.search(grounded_pattern, response)
        matches_actions = re.search(action_pattern, response)

        if matches_history:
            grounded_operation = matches_history.group(1)
            history_step.append(grounded_operation)
        if matches_actions:
            action_operation = matches_actions.group(1)
            history_action.append(action_operation)
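
        # CogAgent reports box coordinates on a 0-1000 scale; dividing by 1000
        # yields the 0-1 fractions that draw_boxes_on_image expects.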
        box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
        matches = re.findall(box_pattern, response)
        if matches:
            boxes = [[int(v) / 1000 for v in match] for match in matches]
            base_name = os.path.splitext(os.path.basename(img_path))[0]
            output_file_name = f"{base_name}_{round_num}.png"
            output_path = os.path.join(args.output_image_path, output_file_name)
            draw_boxes_on_image(image, boxes, output_path)

            ans = {
                "query": f"Round {round_num} query:\n{query}",
                "response": response,
                "output_path": output_path,
            }
            res.append(ans)
            round_num += 1
    with open(args.output_json, "w", encoding="utf-8") as file:
        json.dump(res, file, ensure_ascii=False, indent=4)
if __name__ == "__main__":
    main()