File size: 5,858 Bytes
0fc5095
bfc11d5
0fc5095
 
 
 
 
 
 
 
 
 
 
 
 
0375f07
 
 
 
 
 
 
 
 
 
 
 
0fc5095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e285866
0fc5095
 
0375f07
0fc5095
 
0375f07
0fc5095
bfc11d5
0fc5095
 
 
 
 
 
 
 
 
 
39f8e6b
0fc5095
 
 
 
 
 
987cc76
 
591bd80
 
 
 
 
 
4e366f0
 
 
 
 
 
 
 
987cc76
4e366f0
 
 
 
 
 
 
0fc5095
4e366f0
0fc5095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
987cc76
 
 
 
 
 
0fc5095
 
 
 
 
 
 
 
987cc76
 
 
 
 
0fc5095
 
987cc76
 
0fc5095
39f8e6b
987cc76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from typing import Optional
import spaces

import gradio as gr
import numpy as np
import torch
from PIL import Image
import io


import base64, os
from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
import torch
from PIL import Image

# Model setup runs once at import time.
# A disabled alternative loads the same two models through the helpers
# imported from utils:
#   get_yolo_model(model_path='weights/icon_detect/best.pt')
#   get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")

# Icon detector: YOLO weights, moved to the GPU.
from ultralytics import YOLO
yolo_model = YOLO('weights/icon_detect/best.pt')
yolo_model = yolo_model.to('cuda')

# Icon captioner: the processor comes from the public Florence-2 base
# checkpoint, the model weights from the local fine-tuned directory (fp16).
from transformers import AutoProcessor, AutoModelForCausalLM
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "weights/icon_caption_florence",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = model.to('cuda')

# Bundle handed to get_som_labeled_img as `caption_model_processor`.
caption_model_processor = dict(processor=processor, model=model)
print('finish loading model!!!')


# Target platform of the screenshots; selects the box-drawing preset below.
platform = 'pc'

# Text settings shared by every preset; only padding and line thickness vary.
_shared_text_style = {'text_scale': 0.8, 'text_thickness': 2}
_draw_bbox_presets = {
    'pc': dict(_shared_text_style, text_padding=2, thickness=2),
    'web': dict(_shared_text_style, text_padding=3, thickness=3),
    'mobile': dict(_shared_text_style, text_padding=3, thickness=3),
}

# An unrecognised platform leaves draw_bbox_config undefined.
if platform in _draw_bbox_presets:
    draw_bbox_config = _draw_bbox_presets[platform]



# Markdown/HTML header shown at the top of the demo page (passed to gr.Markdown).
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
    <a href="https://arxiv.org/pdf/2408.00203">
        <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
</div>

OmniParser is a screen parsing tool to convert general GUI screen to structured elements. ✅
"""

# DEVICE = torch.device('cuda')

def _box_id_for(index, content):
    """Return the label id for one parsed entry, or None if none can be found.

    Entries starting with 'Text Box ID' are keyed by their position in the
    parsed list; other entries are expected to contain 'Icon Box ID <id>:'
    and are keyed by that id.
    """
    if content.startswith('Text Box ID'):
        return str(index)
    pieces = content.split('Icon Box ID ')
    if len(pieces) < 2:
        # Neither a text box nor an icon box entry: no id to look up.
        return None
    return pieces[1].split(':')[0]


def _attach_coordinates(parsed_entries, label_coordinates):
    """Append the rounded coordinates of each entry's box, when available.

    Entries whose id is unknown or missing from `label_coordinates` are
    passed through unchanged.
    """
    combined = []
    for index, content in enumerate(parsed_entries):
        box_id = _box_id_for(index, content)
        coords = None if box_id is None else label_coordinates.get(box_id)
        # Compare against None explicitly: truth-testing a numpy array raises.
        if coords is None:
            combined.append(content)
            continue
        # NOTE(review): the boxes are requested with output_coord_in_ratio=True,
        # so these values are presumably ratios in [0, 1] and rounding would
        # collapse them to 0/1 - confirm the intended precision.
        rounded = [round(value) for value in coords]
        combined.append(f"{content} | Coordinates: {rounded}")
    return combined


@torch.inference_mode()
@spaces.GPU(duration=65)
def process(
    image_input,
    box_threshold,
    iou_threshold
) -> "tuple[Image.Image, str, str]":
    """Parse an uploaded screenshot into labeled boxes and text descriptions.

    The image is saved to disk, run through OCR (`check_ocr_box`) and then
    through `get_som_labeled_img`, which returns the annotated image as
    base64, a mapping of box id to coordinates, and the parsed entries.

    Args:
        image_input: Uploaded image (PIL, per the Gradio input component).
        box_threshold: Forwarded to get_som_labeled_img as BOX_TRESHOLD.
        iou_threshold: Forwarded to get_som_labeled_img as iou_threshold.

    Returns:
        A 3-tuple: the annotated PIL image, the parsed entries joined by
        newlines, and the string form of the list of entries with their
        coordinates appended.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image_input is None:
        raise gr.Error('Please upload an image before submitting.')

    image_save_path = 'imgs/saved_image_demo.png'
    # The target directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    ocr_bbox_rslt, _ = check_ocr_box(image_save_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=True)
    text, ocr_bbox = ocr_bbox_rslt
    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold)
    image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
    print('finish processing')

    # Keep the entries as a list for the coordinate merge. Enumerating the
    # newline-joined string instead would walk it character by character.
    parsed_entries = list(parsed_content_list)
    parsed_text = '\n'.join(parsed_entries)
    combined_content = _attach_coordinates(parsed_entries, label_coordinates)

    return image, parsed_text, str(combined_content)


# Page layout: header on top, then inputs on the left and results on the right.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        # Left column: the screenshot, both thresholds and the submit button.
        with gr.Column():
            image_input_component = gr.Image(label='Upload image', type='pil')
            box_threshold_component = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.05,
                label='Box Threshold',
            )
            iou_threshold_component = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.1,
                label='IOU Threshold',
            )
            submit_button_component = gr.Button(variant='primary', value='Submit')

        # Right column: annotated image, parsed text and per-box coordinates.
        with gr.Column():
            image_output_component = gr.Image(label='Image Output', type='pil')
            text_output_component = gr.Textbox(
                label='Parsed screen elements',
                placeholder='Text Output',
            )
            coordinates_output_component = gr.Textbox(
                label='Bounding Box Coordinates',
                placeholder='Coordinates will appear here',
                interactive=False,  # read-only
                lines=20,  # tall enough to show many boxes
            )

    # The three values returned by process() fill the three output components,
    # in order.
    submit_button_component.click(
        fn=process,
        inputs=[image_input_component, box_threshold_component, iou_threshold_component],
        outputs=[image_output_component, text_output_component, coordinates_output_component],
    )

# Enable the request queue, then start the server without a public share link.
demo.queue().launch(share=False)

# demo.launch(debug=False, show_error=True, share=True)
# demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
# demo.queue().launch(share=False)