import argparse
import os
import re
import torch
from PIL import Image, ImageDraw
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import List
import json
from tqdm import tqdm


def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
    """
    Draws red bounding boxes on the given image and saves it.

    Parameters:
    - image (PIL.Image.Image): The image on which to draw the bounding boxes.
    - boxes (List[List[float]]): A list of bounding boxes, each defined as
      [x_min, y_min, x_max, y_max]. Coordinates are expected to be normalized (0 to 1).
    - save_path (str): The path to save the annotated image.

    Description:
    Each box coordinate is a fraction of the image dimension. This function converts
    them to actual pixel coordinates and draws a red rectangle to mark the area.
    The annotated image is then saved to the specified path.
    """
    draw = ImageDraw.Draw(image)
    for box in boxes:
        x_min = int(box[0] * image.width)
        y_min = int(box[1] * image.height)
        x_max = int(box[2] * image.width)
        y_max = int(box[3] * image.height)
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
    image.save(save_path)


def main():
    """
    A continuous interactive demo using the CogAgent1.5 model with selectable format prompts.
    The output_image_path is interpreted as a directory. For each round of interaction,
    the annotated image is saved in that directory with the filename:

        {original_image_name_without_extension}_{round_number}.png

    Example:
    python cli_demo.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
        --output_image_path ./results --format_key status_action_op_sensitive
    """
    parser = argparse.ArgumentParser(
        description="Continuous interactive demo with CogAgent model and selectable format."
    )
    parser.add_argument(
        "--model_dir", required=True, help="Path or identifier of the model."
    )
    parser.add_argument(
        "--platform",
        default="Mac",
        help="Platform information string (e.g., 'Mac', 'WIN').",
    )
    parser.add_argument(
        "--max_length", type=int, default=4096, help="Maximum generation length."
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter."
    )
    parser.add_argument(
        "--output_image_path",
        default="results",
        help="Directory to save the annotated images.",
    )
    parser.add_argument(
        "--input_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
        help="Path to the input JSON file containing the test cases.",
    )
    parser.add_argument(
        "--output_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
        help="Path to the output JSON file for the collected responses.",
    )
    parser.add_argument(
        "--format_key",
        default="action_op_sensitive",
        help="Key to select the prompt format.",
    )
    args = parser.parse_args()

    # Dictionary mapping format keys to format strings
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }

    # Ensure the provided format_key is valid
    if args.format_key not in format_dict:
        raise ValueError(
            f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
        )

    # Ensure the output directory exists
    os.makedirs(args.output_image_path, exist_ok=True)

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # For INT8 quantization
        # quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # For INT4 quantization
    ).eval()

    # Initialize platform and selected format strings
    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]

    # Initialize history lists
    history_step = []
    history_action = []

    round_num = 1
    with open(args.input_json, "r") as f:
        data = json.load(f)

    res = []
    for i in tqdm(range(len(data))):
        x = data[i]
        img_path = x['image']
        image = Image.open(img_path).convert("RGB")
        task = x['conversations'][0]['value']

        # Verify history lengths match; skip the current case on mismatch
        try:
            if len(history_step) != len(history_action):
                raise ValueError("Mismatch in lengths of history_step and history_action.")
        except ValueError as e:
            print(f"Warning: {e} - skipping current case")
            continue

        # Format history steps for output
        history_str = "\nHistory steps: "
        for index, (step, action) in enumerate(zip(history_step, history_action)):
            history_str += f"\n{index}. {step}\t{action}"

        # Compose the query with task, platform, and selected format instructions
        query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
        # print(f"Round {round_num} query:\n{query}")

        inputs = tokenizer.apply_chat_template(
            [{"role": "user", "image": image, "content": query}],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)

        # Generation parameters
        gen_kwargs = {
            "max_length": args.max_length,
            "do_sample": True,
            "top_k": args.top_k,
        }

        # Generate response
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs["input_ids"].shape[1]:]
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # print(f"Model response:\n{response}")

        # Extract grounded operation and action
        grounded_pattern = r"Grounded Operation:\s*(.*)"
        action_pattern = r"Action:\s*(.*)"
        matches_history = re.search(grounded_pattern, response)
        matches_actions = re.search(action_pattern, response)
        if matches_history:
            grounded_operation = matches_history.group(1)
            history_step.append(grounded_operation)
        if matches_actions:
            action_operation = matches_actions.group(1)
            history_action.append(action_operation)

        # Extract bounding boxes from the response
        box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
        matches = re.findall(box_pattern, response)
        output_path = None  # stays None if the response contains no boxes
        if matches:
            boxes = [[int(x) / 1000 for x in match] for match in matches]

            # Extract base name of the user's input image (without extension)
            base_name = os.path.splitext(os.path.basename(img_path))[0]

            # Construct the output file name with round number
            output_file_name = f"{base_name}_{round_num}.png"
            output_path = os.path.join(args.output_image_path, output_file_name)
            draw_boxes_on_image(image, boxes, output_path)
            # print(f"Annotated image saved at: {output_path}")

        ans = {
            'query': f"Round {round_num} query:\n{query}",
            'response': response,
            'output_path': output_path,
        }
        res.append(ans)
        round_num += 1

    # print(res)
    with open(args.output_json, "w", encoding="utf-8") as file:
        json.dump(res, file, ensure_ascii=False, indent=4)
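
# Note on the input data: the schema of --input_json below is inferred from the fields
# this script reads (x['image'] and x['conversations'][0]['value']); real files may carry
# additional keys, and the path and task text here are purely illustrative assumptions.
# A minimal entry would look like:
#
#     {
#         "image": "/path/to/screenshot.png",
#         "conversations": [{"value": "Open the settings app and enable dark mode."}]
#     }
#
# Bounding boxes are parsed from the response with the pattern box=[[x1,y1,x2,y2]],
# where coordinates are assumed to be on a 0-1000 scale and are divided by 1000
# before being drawn onto the screenshot.
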
if __name__ == "__main__":
    main()