# Hugging Face Spaces status: Running on Zero
import argparse
import os
import re
import threading
import time
from datetime import datetime, timedelta
from threading import Event, Thread
from typing import List

import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# Process-wide "stop" flag: set by the "Stop Now" handler and polled by
# ``predict`` while it streams tokens.
# NOTE(review): this single flag is shared by every session served by this
# process, so a stop from one user also interrupts another user's generation.
stop_event = threading.Event()
def delete_old_files(
    directories=("./outputs", "./gradio_tmp"),
    max_age=timedelta(minutes=10),
    interval=600,
    run_once=False,
):
    """Periodically delete files older than *max_age* from *directories*.

    Each sweep looks only at regular files directly inside each directory
    (subdirectories are left alone) and removes those whose modification
    time is older than ``now - max_age``. Between sweeps the function sleeps
    *interval* seconds, forever, unless *run_once* is true, in which case a
    single sweep is performed and the function returns.

    Cleanup is best-effort: a directory that does not exist (yet) is
    skipped, and a file that disappears or cannot be removed mid-sweep is
    ignored until the next sweep, so the background thread never dies.

    Args:
        directories: Directories to sweep.
        max_age: Files last modified longer ago than this are deleted.
        interval: Seconds to sleep between sweeps.
        run_once: Perform exactly one sweep and return instead of looping.
    """
    while True:
        cutoff = datetime.now() - max_age
        for directory in directories:
            try:
                names = os.listdir(directory)
            except (FileNotFoundError, NotADirectoryError):
                # e.g. the output directory is only created once the first
                # annotated image is saved; nothing to clean until then.
                continue
            for name in names:
                path = os.path.join(directory, name)
                if not os.path.isfile(path):
                    continue
                try:
                    modified = datetime.fromtimestamp(os.path.getmtime(path))
                    if modified < cutoff:
                        os.remove(path)
                except OSError:
                    # Removed concurrently or not deletable right now; the
                    # next sweep will retry.
                    continue
        if run_once:
            return
        time.sleep(interval)
# Start the file-cleanup loop in the background as soon as the module is
# imported. daemon=True: the thread does not keep the process alive on exit.
_cleanup_thread = threading.Thread(daemon=True, target=delete_old_files)
_cleanup_thread.start()
def draw_boxes_on_image(
    image: Image.Image,
    boxes: List[List[float]],
    save_path: str,
    outline: str = "red",
    width: int = 3,
):
    """Outline each box on *image* (modified in place) and save it to *save_path*.

    Args:
        image: Image to draw on; it is modified in place.
        boxes: One ``[x1, y1, x2, y2]`` entry per box, expressed as fractions
            of the image width (x) and height (y). Corners may be given in
            either order.
        save_path: Destination file; PIL infers the format from the extension.
        outline: Rectangle outline colour.
        width: Outline width in pixels.
    """
    draw = ImageDraw.Draw(image)
    img_w, img_h = image.size
    for box in boxes:
        # Recent Pillow versions raise ValueError when x1 < x0 or y1 < y0;
        # model output is not guaranteed to list the corners in order.
        left, right = sorted((int(box[0] * img_w), int(box[2] * img_w)))
        top, bottom = sorted((int(box[1] * img_h), int(box[3] * img_h)))
        draw.rectangle([left, top, right, bottom], outline=outline, width=width)
    image.save(save_path)
def preprocess_messages(history, img_path, platform_str, format_str):
    """Build the CogAgent prompt for the latest round and load the screenshot.

    The query has the form::

        Task: <user message of the last history entry>
        History steps: 
        0. <grounded operation from an earlier answer>
        1. ...
        <platform_str><format_str>

    Only the text after ``Grounded Operation:`` (up to the end of that line)
    in each model message is carried into the prompt. Steps are numbered
    from 0.

    Args:
        history: Chat history as ``[user_msg, model_msg]`` pairs; ``None`` is
            treated as empty (the task then becomes "No task provided").
        img_path: Filesystem path of the uploaded screenshot.
        platform_str: Platform line appended to the prompt.
        format_str: Answer-format instruction appended to the prompt.

    Returns:
        ``(query, image)``: the prompt string and the screenshot as an RGB
        PIL image.

    Raises:
        ValueError: If ``img_path`` is ``None`` (no screenshot uploaded).
    """
    if img_path is None:
        raise ValueError("A screenshot is required but img_path is None.")

    history = history or []
    grounded_pattern = r"Grounded Operation:\s*(.*)"
    steps = []
    for _user_msg, model_msg in history:
        # The round currently being answered (or a UI-normalised entry) may
        # hold "" or None as the model message.
        found = re.search(grounded_pattern, model_msg or "")
        if found:
            steps.append(found.group(1))

    history_str = "\nHistory steps: " + "".join(
        f"\n{i}. {step}" for i, step in enumerate(steps)
    )
    task = history[-1][0] if history else "No task provided"
    query = f"Task: {task}{history_str}\n{platform_str}{format_str}"

    # convert() returns an independent, fully loaded copy, so the file
    # handle can be closed right away.
    with Image.open(img_path) as screenshot:
        image = screenshot.convert("RGB")
    return query, image
def _make_stopping_criteria(event):
    """Return a ``StoppingCriteriaList`` that ends ``model.generate`` once *event* is set.

    Merely breaking out of the streamer loop does not stop the generation
    thread; without this criterion it keeps decoding up to ``max_length`` in
    the background after "Stop Now" is pressed.
    """

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One "done" flag per sequence in the batch.
            return torch.full(
                (input_ids.shape[0],),
                event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    return StoppingCriteriaList([_StopOnEvent()])


# ``box=[[x1,y1,x2,y2]]`` (inner brackets optional, spaces tolerated).
_BOX_PATTERN = re.compile(
    r"box=\[\[?\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]?\]"
)


def _save_annotated_image(response, history, img_path, output_dir):
    """Draw every box found in *response* on a fresh copy of the screenshot.

    Box coordinates in the response are divided by 1000 to obtain fractions
    of the image size. The file is named ``<screenshot name>_<n>.png``, where
    ``n`` counts the history rounds that have both a user and a model message.

    Returns:
        The saved image path, or ``None`` if *response* contains no box.
    """
    matches = _BOX_PATTERN.findall(response)
    if not matches:
        return None
    boxes = [[int(value) / 1000 for value in match] for match in matches]
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    round_num = sum(1 for user_msg, model_msg in history if user_msg and model_msg)
    output_path = os.path.join(output_dir, f"{base_name}_{round_num}.png")
    with Image.open(img_path) as screenshot:
        image = screenshot.convert("RGB")
    draw_boxes_on_image(image, boxes, output_path)
    return output_path


# NOTE(review): ``spaces`` is imported but nothing visible here is decorated
# with ``@spaces.GPU``; on ZeroGPU hardware GPU work is normally expected to
# run inside such a function. Confirm against the deployment.
def predict(history, max_length, img_path, platform_str, format_str, output_dir):
    """Stream CogAgent's answer for the last round of *history* (Gradio generator).

    Decoded text is appended to ``history[-1][1]`` and ``(history, None)`` is
    yielded after every new chunk. When generation completes, any
    ``box=[[...]]`` in the answer is drawn on the screenshot and a final
    ``(history, annotated_image_path_or_None)`` is yielded.

    If ``stop_event`` is set while streaming, generation is halted and the
    round being answered (the last history entry) is removed before the
    final yield.

    Args:
        history: Chat history as ``[user_msg, model_msg]`` lists; the last
            entry is the round to answer.
        max_length: ``max_length`` passed to ``model.generate``.
        img_path: Path of the uploaded screenshot, or ``None``.
        platform_str: Platform line for the prompt.
        format_str: Answer-format instruction for the prompt.
        output_dir: Directory for annotated images.

    Raises:
        gr.Error: If no screenshot has been uploaded.
        Exception: Whatever ``model.generate`` raised in the worker thread.
    """
    if img_path is None:
        raise gr.Error("Please upload a screenshot before submitting a task.")

    # Each request starts un-stopped.
    stop_event.clear()
    # ``user`` has already appended the round being answered, so a stop
    # rolls history back to everything before that last entry.
    prev_len = max(len(history) - 1, 0)

    query, image = preprocess_messages(history, img_path, platform_str, format_str)
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "image": image, "content": query}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "position_ids": inputs["position_ids"],
        "images": inputs["images"],
        "streamer": streamer,
        # The slider value may arrive as a float; generate expects an int.
        "max_length": int(max_length),
        # Sampling restricted to top_k=1 is effectively greedy decoding.
        "do_sample": True,
        "top_k": 1,
        "stopping_criteria": _make_stopping_criteria(stop_event),
    }

    errors = []

    def run_generation():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised in the request thread below
            errors.append(exc)
            # Release the streamer loop now instead of after its timeout.
            streamer.end()

    worker = Thread(target=run_generation)
    worker.start()

    for new_token in streamer:
        if stop_event.is_set():
            break
        if new_token:
            history[-1][1] = (history[-1][1] or "") + new_token
            yield history, None

    if stop_event.is_set():
        # Roll back the interrupted round, then wait for the worker; the
        # stopping criterion ends it at its next decoding step.
        del history[prev_len:]
        yield history, None
        worker.join()
        return

    worker.join()
    if errors:
        raise errors[0]

    output_path = _save_annotated_image(
        history[-1][1] or "", history, img_path, output_dir
    )
    yield history, output_path
def user(task, history):
    """Start a new chat round for *task* and clear the task textbox.

    The incoming *history* is not mutated; a new list with ``[task, ""]``
    appended is returned. ``None`` is treated as an empty history.

    Returns:
        ``("", new_history)``: the value for the Task textbox and the chatbot.
    """
    return "", [*(history or []), [task, ""]]
def undo_last_round(history, output_img):
    """Remove the most recent chat round and clear the annotated image.

    *history* is modified in place (its last entry is popped) and returned.
    An empty or ``None`` history is returned unchanged. *output_img* is
    unused.

    Returns:
        ``(history, None)`` for the chatbot and the output image component.
    """
    if not history:
        return history, None
    history.pop()
    return history, None
def clear_all_history():
    """Reset the chat and the annotated image.

    Returns:
        ``(None, None)`` for the chatbot and the output image component.
    """
    cleared_chat = None
    cleared_image = None
    return cleared_chat, cleared_image
def stop_now():
    """Request that the running generation stop, leaving both outputs as they are.

    Sets the module-level ``stop_event`` that ``predict`` checks while
    streaming. Both returned values are no-op ``gr.update()`` objects, so
    the chatbot and the annotated image are not changed by this handler.
    """
    stop_event.set()
    no_change = (gr.update(), gr.update())
    return no_change
def main():
    """Parse CLI options, load CogAgent, build the Gradio UI, and launch it.

    Command-line options:
        --model_dir: Local path or Hub id passed to ``from_pretrained``.
        --format_key: Key into ``format_dict`` below; selects the
            answer-format instruction appended to every prompt.
        --platform: Platform name inserted as "(Platform: ...)" in the prompt.
        --output_dir: Directory where annotated screenshots are written.

    Side effects: binds the module-level ``tokenizer`` and ``model`` globals
    that ``predict`` uses, then launches the Gradio app.

    Raises:
        ValueError: If ``--format_key`` is not a key of ``format_dict``.
    """
    parser = argparse.ArgumentParser(description="CogAgent Gradio Demo")
    parser.add_argument("--model_dir", default="THUDM/cogagent-9b-20241220", help="Path or identifier of the model.")
    parser.add_argument("--format_key", default="action_op_sensitive", help="Key to select the prompt format.")
    parser.add_argument("--platform", default="Mac", help="Platform information string.")
    # NOTE(review): delete_old_files sweeps a hard-coded "./outputs"; a
    # different --output_dir is not cleaned up automatically.
    parser.add_argument("--output_dir", default="outputs", help="Directory to save annotated images.")
    args = parser.parse_args()
    # Answer-format instruction appended to the end of every prompt.
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)"
    }
    if args.format_key not in format_dict:
        raise ValueError(f"Invalid format_key. Available keys: {list(format_dict.keys())}")
    # predict() reads these as module-level globals.
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
    ).eval()
    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]
    with gr.Blocks(analytics_enabled=False) as demo:
        gr.HTML("<h1 align='center'>CogAgent-9B-20241220 Demo</h1>")
        gr.HTML(
            """
            <p align='center' style='color:red;'>This demo is for learning and communication purposes only. Users must assume responsibility for the risks associated with AI-generated planning and operations.</p>
            <p align='center' style='color:red;'>In this demo, the model assumes that the user is using a Mac operating system. Therefore, it is recommended to upload screenshots taken on a Mac.</p>
            <p align='left' style='color:black;'>1. Upload a screenshot from your computer (must be from a Mac, and a full-screen screenshot).<br>
            2. Provide your instructions to CogAgent (e.g., send a message to XXX).<br>
            3. Wait for CogAgent to return specific operations. If bounding boxes (Bbox) are detected, they will be displayed in the image area on the right.</p>
            <p align='left' style='color:black;'>The model will only return the next step's instructions. The online demo cannot control your computer. Please visit the <a href="https://github.com/THUDM/CogAgent">GitHub repository</a> for the full version of the demo.</p>
            """
        )
        # First row: screenshot upload next to the annotated result.
        with gr.Row():
            img_path = gr.Image(label="Upload a Screenshot", type="filepath", height=400)
            output_img = gr.Image(type="filepath", label="Annotated Image(If Bbox Return)", height=400, interactive=False)
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=300)
                task = gr.Textbox(show_label=True, placeholder="Input...", label="Task")
                submitBtn = gr.Button("Submit")
            with gr.Column(scale=1):
                # NOTE(review): max_length is a total length passed to
                # generate(); the slider minimum of 0 presumably leaves no
                # room for any generated tokens.
                max_length = gr.Slider(0, 8192, value=1024, step=1.0, label="Maximum length", interactive=True)
                undo_last_round_btn = gr.Button("Back to Last Round")
                clear_history_btn = gr.Button("Clear All History")
                # Red "Stop Now" button: clicking it interrupts generation
                # immediately and rolls back the current round of history.
                stop_now_btn = gr.Button("Stop Now", variant="stop")
        # Submit: `user` appends the new round without queuing, then
        # `predict` streams the answer through the queue.
        submitBtn.click(
            user, [task, chatbot], [task, chatbot], queue=False
        ).then(
            predict,
            [chatbot, max_length, img_path, gr.State(platform_str), gr.State(format_str),
             gr.State(args.output_dir)],
            [chatbot, output_img],
            queue=True
        )
        undo_last_round_btn.click(undo_last_round, [chatbot, output_img], [chatbot, output_img], queue=False)
        clear_history_btn.click(clear_all_history, None, [chatbot, output_img], queue=False)
        # queue=False: the stop click bypasses the queue so it is handled
        # immediately.
        stop_now_btn.click(stop_now, None, [chatbot, output_img], queue=False)
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()