|
---
library_name: transformers
tags: []
---
|
|
|
# Introduction |
|
|
|
Reinforcement learning (RL) with an algorithm such as GRPO helps with grounding because of its inherent objective alignment: the model is rewarded for successful clicks rather than for producing long textual Chain-of-Thought (CoT) reasoning. Unlike approaches that rely heavily on verbose CoT, GRPO directly incentivizes actionable, grounded responses. Based on findings from our [blog](https://huggingface.co/blog/HelloKKMe/grounding-r1), we share state-of-the-art GUI grounding models trained using GRPO.
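
The objective can be made concrete with a small sketch. The function below is a minimal illustration of a click-accuracy reward, not the exact reward used in training: it assumes the model predicts a single (x, y) point and that each training example provides a ground-truth bounding box for the target element; the name `click_reward` and the box format are hypothetical.

```python
# Minimal sketch of a click-accuracy reward for GUI grounding (illustrative only;
# not the exact reward used to train GTA1). Assumes a ground-truth bounding box
# (left, top, right, bottom) for the target element and a predicted (x, y) point.
def click_reward(pred_xy, gt_box):
    x, y = pred_xy
    left, top, right, bottom = gt_box
    # Reward 1.0 when the predicted point lands inside the target element, else 0.0.
    return 1.0 if (left <= x <= right and top <= y <= bottom) else 0.0

# Example: a prediction inside the box earns the reward.
print(click_reward((105, 42), (90, 30, 150, 60)))  # 1.0
```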
|
|
|
# Performance |
|
|
|
We follow the standard evaluation protocol and benchmark our models on three challenging datasets: ScreenSpot-V2, ScreenSpot-Pro, and OSWORLD-G. Our method consistently achieves the best results among open-source model families. The comparative results are shown below:
|
|
|
| **Model** | **Size** | **Open Source** | **ScreenSpot-V2** | **ScreenSpot-Pro** | **OSWORLD-G** |
|-------------------------|:--------:|:---------------:|:--------------------------:|:--------------------------:|:--------------------------:|
| OpenAI CUA | — | ❌ | 87.9 | 23.4 | — |
| Claude 3.7 | — | ❌ | 87.6 | 27.7 | — |
| JEDI-7B | 7B | ✅ | 91.7 | 39.5 | 54.1 |
| SE-GUI | 7B | ✅ | 90.3 | 47.0 | — |
| UI-TARS | 7B | ✅ | 91.6 | 35.7 | 47.5 |
| UI-TARS-1.5* | 7B | ✅ | 89.7* | 42.0* | 64.2* |
| UGround-v1-7B | 7B | ✅ | — | 31.1 | 36.4 |
| Qwen2.5-VL-32B-Instruct | 32B | ✅ | 91.9* | 48.0 | 59.6* |
| UGround-v1-72B | 72B | ✅ | — | 34.5 | — |
| Qwen2.5-VL-72B-Instruct | 72B | ✅ | 94.00* | 53.3 | 62.2* |
| UI-TARS | 72B | ✅ | 90.3 | 38.1 | — |
| GTA1 (Ours) | 7B | ✅ | 92.4 <sub>*(Δ +2.7)*</sub> | 50.1 <sub>*(Δ +8.1)*</sub> | 67.7 <sub>*(Δ +3.5)*</sub> |
| GTA1 (Ours) | 32B | ✅ | 93.2 <sub>*(Δ +1.3)*</sub> | 53.6 <sub>*(Δ +5.6)*</sub> | 61.9 <sub>*(Δ +2.3)*</sub> |
| GTA1 (Ours) | 72B | ✅ | 94.8 <sub>*(Δ +0.8)*</sub> | 58.4 <sub>*(Δ +5.1)*</sub> | 66.7 <sub>*(Δ +4.5)*</sub> |
|
|
|
|
|
> **Note:**
> - Model size is indicated in billions (B) of parameters.
> - A dash (—) denotes results that are currently unavailable.
> - A superscript asterisk (*) denotes our evaluated result.
> - UI-TARS-1.5 7B, Qwen2.5-VL-32B-Instruct, and Qwen2.5-VL-72B-Instruct serve as the baselines for our 7B, 32B, and 72B models, respectively.
> - Δ indicates the performance improvement of our model over its baseline.
|
|
|
# Inference |
|
Below is a code snippet demonstrating how to run inference using a trained model. |
|
|
|
```python
import re

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
'''
SYSTEM_PROMPT = SYSTEM_PROMPT.strip()


# Extract the first "(x,y)" coordinate pair from the model output
def extract_coordinates(raw_string):
    try:
        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
        return [tuple(int(float(v)) for v in match) for match in matches][0]
    except (IndexError, ValueError):
        return 0, 0


# Load model and processor
model_path = "HelloKKMe/GTA1-72B"
max_new_tokens = 32

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto"
)
processor = AutoProcessor.from_pretrained(
    model_path,
    min_pixels=3136,
    max_pixels=4096 * 2160
)

# Load the screenshot and resize it to the resolution the model will see
image = Image.open("file path")      # path to your screenshot
instruction = "description"          # natural-language description of the target element
width, height = image.width, image.height

resized_height, resized_width = smart_resize(
    image.height,
    image.width,
    factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
    min_pixels=processor.image_processor.min_pixels,
    max_pixels=processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height

# Prepare system and user messages
system_message = {
    "role": "system",
    "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
}
user_message = {
    "role": "user",
    "content": [
        {"type": "image", "image": resized_image},
        {"type": "text", "text": instruction}
    ]
}

# Tokenize and prepare inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = processor.apply_chat_template([system_message, user_message], tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)

# Generate prediction
output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=1.0, use_cache=True)
generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

# Extract the predicted point and rescale it back to the original image resolution
pred_x, pred_y = extract_coordinates(output_text)
pred_x *= scale_x
pred_y *= scale_y
print(pred_x, pred_y)
```
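
If the model is driving a live desktop, the rescaled point can be passed to an automation library. The snippet below is a hypothetical usage sketch, not part of the GTA1 codebase: it assumes the screenshot was a full-screen capture at the display's native resolution (so image coordinates equal screen coordinates) and that `pyautogui` is installed; the example values stand in for the rescaled prediction from the snippet above.

```python
# Hypothetical follow-up: click the predicted location with pyautogui.
# Assumes a full-screen screenshot at native resolution, so the rescaled
# (pred_x, pred_y) from the snippet above maps directly to screen coordinates.
import pyautogui

pred_x, pred_y = 640, 360  # placeholder values; use the rescaled prediction

pyautogui.moveTo(pred_x, pred_y, duration=0.2)  # move the cursor to the target
pyautogui.click()                               # perform the click
```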
|
|
|
Refer to our [code](https://github.com/Yan98/GTA1) for more details. |