---
library_name: transformers
tags: []
---
# Introduction
Reinforcement learning (RL), e.g., GRPO, helps with grounding because its objective is aligned with the task: it rewards successful clicks rather than encouraging long textual Chain-of-Thought (CoT) reasoning. Unlike approaches that rely heavily on verbose CoT reasoning, GRPO directly incentivizes actionable and grounded responses. Based on findings from our [blog](https://huggingface.co/blog/HelloKKMe/grounding-r1), we share state-of-the-art GUI grounding models trained using GRPO.
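To make this objective concrete, below is a minimal sketch of a click-hit reward of the kind a GRPO-style trainer could optimize: the reward is 1 when the predicted point falls inside the target element's bounding box and 0 otherwise. The function name and bounding-box format are illustrative assumptions, not the exact reward used in our training.
```python
def click_reward(pred_x: float, pred_y: float, target_box: tuple) -> float:
    """Binary grounding reward (illustrative sketch, not the exact training reward).

    Returns 1.0 if the predicted click (pred_x, pred_y) lands inside the target
    element's bounding box (x1, y1, x2, y2), else 0.0.
    """
    x1, y1, x2, y2 = target_box
    return 1.0 if (x1 <= pred_x <= x2 and y1 <= pred_y <= y2) else 0.0
```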
# Performance
We follow the standard evaluation protocol and benchmark our model on three challenging datasets. Our method consistently achieves the best results among all open-source model families. Below are the comparative results:
| **Model** | **Size** | **Open Source** | **ScreenSpot-V2** | **ScreenSpotPro** | **OSWORLD-G** |
|-------------------|:--------:|:---------------:|:-----------------:|:-----------------:|:-----------------:|
| OpenAI CUA | — | ❌ | 87.9 | 23.4 | — |
| Claude 3.7 | — | ❌ | 87.6 | 27.7 | — |
| JEDI-7B | 7B | ✅ | 91.7 | 39.5 | 54.1 |
| SE-GUI | 7B | ✅ | 90.3 | 47.0 | — |
| UI-TARS | 7B | ✅ | 91.6 | 35.7 | 47.5 |
| UI-TARS-1.5* | 7B | ✅ | 89.7* | 42.0* | 64.2* |
| UGround-v1-7B | 7B | ✅ | — | 31.1 | 36.4 |
| Qwen2.5-VL-32B-Instruct | 32B | ✅ | 91.9* | 48.0 | 59.6* |
| UGround-v1-72B | 72B | ✅ | — | 34.5 | — |
| Qwen2.5-VL-72B-Instruct | 72B | ✅ | 94.0* | 53.3 | 62.2* |
| UI-TARS | 72B | ✅ | 90.3 | 38.1 | — |
| GTA1 (Ours) | 7B | ✅ | 92.4 <sub>*(∆ +2.7)*</sub> | 50.1<sub>*(∆ +8.1)*</sub> | 67.7 <sub>*(∆ +3.5)*</sub> |
| GTA1 (Ours) | 32B | ✅ | 93.2 <sub>*(∆ +1.3)*</sub> | 53.6 <sub>*(∆ +5.6)*</sub> | 61.9<sub>*(∆ +2.3)*</sub> |
| GTA1 (Ours) | 72B | ✅ | 94.8<sub>*(∆ +0.8)*</sub> | 58.4 <sub>*(∆ +5.1)*</sub> | 66.7<sub>*(∆ +4.5)*</sub> |
> **Note:**
> - Model size is indicated in billions (B) of parameters.
> - A dash (—) denotes results that are currently unavailable.
> - A superscript asterisk (﹡) denotes our evaluated result.
> - UI-TARS-1.5 7B, Qwen2.5-VL-32B-Instruct, and Qwen2.5-VL-72B-Instruct are used as our baseline models.
> - ∆ indicates the performance improvement of our model over its baseline.
# Inference
Below is a code snippet demonstrating how to run inference using a trained model.
```python
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch
import re
SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''
SYSTEM_PROMPT = SYSTEM_PROMPT.strip()
# Extract the first (x, y) coordinate pair from the model output
def extract_coordinates(raw_string):
    try:
        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
        return tuple(int(float(c)) for c in matches[0])
    except (IndexError, ValueError):
        return 0, 0
# Load model and processor
model_path = "HelloKKMe/GTA1-72B"
max_new_tokens = 32
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
processor = AutoProcessor.from_pretrained(
model_path,
min_pixels=3136,
    max_pixels=4096 * 2160
)
# Load and resize image
image = Image.open("file path")  # Path to the GUI screenshot
instruction = "description"  # Natural-language description of the target element
width, height = image.width, image.height
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
min_pixels=processor.image_processor.min_pixels,
max_pixels=processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
# Scale factors to map predicted coordinates from the resized image back to the original image
scale_x, scale_y = width / resized_width, height / resized_height
# Prepare system and user messages
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height,width=resized_width)
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
}
# Tokenize and prepare inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = processor.apply_chat_template([system_message, user_message], tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)
# Generate prediction
output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text)
pred_x *= scale_x
pred_y *= scale_y
print(pred_x, pred_y)
```
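As an optional sanity check (not part of the original pipeline), the short sketch below overlays the rescaled prediction on the original screenshot with PIL; the marker radius and output filename are arbitrary choices.
```python
from PIL import ImageDraw

# Illustrative check: mark the predicted point on the original screenshot.
annotated = image.copy()
draw = ImageDraw.Draw(annotated)
r = 8  # marker radius in pixels (arbitrary)
draw.ellipse((pred_x - r, pred_y - r, pred_x + r, pred_y + r), outline="red", width=3)
annotated.save("prediction_overlay.png")
```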
Refer to our [code](https://github.com/Yan98/GTA1) for more details. |