from typing import Optional
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
import io
import base64, os
from huggingface_hub import snapshot_download
import traceback
import warnings
import sys
# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*_supports_sdpa.*")
# CRITICAL: Fix Florence2 model before any imports
def fix_florence2_import():
"""Pre-patch the Florence2 model class before it's imported"""
import importlib.util
import types
# Create a custom import hook
class Florence2ImportHook:
def find_spec(self, fullname, path, target=None):
if "florence2" in fullname.lower() or "modeling_florence2" in fullname:
return importlib.util.spec_from_loader(fullname, Florence2Loader())
return None
class Florence2Loader:
def create_module(self, spec):
return None
def exec_module(self, module):
# Load the original module
import importlib.machinery
import importlib.util
# Find the actual florence2 module
for path in sys.path:
florence_path = os.path.join(path, "modeling_florence2.py")
if os.path.exists(florence_path):
spec = importlib.util.spec_from_file_location("modeling_florence2", florence_path)
if spec and spec.loader:
spec.loader.exec_module(module)
# Patch the module after loading
if hasattr(module, 'Florence2ForConditionalGeneration'):
original_init = module.Florence2ForConditionalGeneration.__init__
def patched_init(self, config):
                                # Add the missing attribute before running the original __init__
self._supports_sdpa = False
original_init(self, config)
module.Florence2ForConditionalGeneration.__init__ = patched_init
module.Florence2ForConditionalGeneration._supports_sdpa = False
break
# Install the import hook
hook = Florence2ImportHook()
sys.meta_path.insert(0, hook)
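    # Finders at the front of sys.meta_path are consulted before Python's default
    # import machinery, which is what lets this hook intercept the dynamically
    # downloaded modeling_florence2 module by name.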
# Apply the fix before any model imports
try:
fix_florence2_import()
except Exception as e:
print(f"Warning: Could not apply import hook: {e}")
# Alternative fix: Monkey-patch transformers before importing utils
def monkey_patch_transformers():
"""Monkey patch transformers to handle _supports_sdpa"""
try:
import transformers.modeling_utils as modeling_utils
original_check = modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation
def patched_check(self, *args, **kwargs):
# Add the attribute if missing
if not hasattr(self, '_supports_sdpa'):
self._supports_sdpa = False
try:
return original_check(self, *args, **kwargs)
except AttributeError as e:
if '_supports_sdpa' in str(e):
# Return a safe default
return "eager"
raise
modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation = patched_check
        # Also patch attribute access so a missing _supports_sdpa resolves to a safe
        # default instead of raising; going through AttributeError here (rather than
        # hasattr) avoids infinite recursion inside __getattribute__.
        original_getattr = modeling_utils.PreTrainedModel.__getattribute__
        def patched_getattr(self, name):
            try:
                return original_getattr(self, name)
            except AttributeError:
                if name == '_supports_sdpa':
                    return False
                raise
        modeling_utils.PreTrainedModel.__getattribute__ = patched_getattr
print("Successfully patched transformers for Florence2 compatibility")
except Exception as e:
print(f"Warning: Could not patch transformers: {e}")
# Apply the monkey patch
monkey_patch_transformers()
# Now import the utils after patching
from util.utils import check_ocr_box, get_yolo_model, get_som_labeled_img
# Download repository (if not already downloaded)
repo_id = "microsoft/OmniParser-v2.0"
local_dir = "weights"
if not os.path.exists(local_dir):
snapshot_download(repo_id=repo_id, local_dir=local_dir)
print(f"Repository downloaded to: {local_dir}")
else:
print(f"Weights already exist at: {local_dir}")
# Custom function to load caption model with proper error handling
def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
"""Safely load caption model with multiple fallback methods"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
try:
# Method 1: Try the original function with patching
from util.utils import get_caption_model_processor
return get_caption_model_processor(model_name, model_name_or_path)
except AttributeError as e:
if '_supports_sdpa' in str(e):
print(f"SDPA error detected, trying alternative loading method...")
else:
raise
# Method 2: Load directly with specific configuration
try:
from transformers import AutoProcessor, AutoModelForCausalLM
print(f"Loading caption model from {model_name_or_path} with alternative method...")
# Load processor
processor = AutoProcessor.from_pretrained(
model_name_or_path,
trust_remote_code=True,
revision="main"
)
# Try to load model with different configurations
configs_to_try = [
{"attn_implementation": "eager", "use_cache": False},
{"use_flash_attention_2": False, "use_cache": False},
{"torch_dtype": torch.float32}, # Try float32 instead of float16
]
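        # Fallback order: eager attention sidesteps the SDPA code path entirely;
        # use_flash_attention_2 is an older from_pretrained flag that newer
        # transformers releases may ignore or reject (the loop then just moves on);
        # plain float32 is the last resort.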
model = None
for config in configs_to_try:
try:
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
trust_remote_code=True,
device_map="auto" if torch.cuda.is_available() else None,
**config
)
# Ensure the attribute exists
if not hasattr(model, '_supports_sdpa'):
model._supports_sdpa = False
print(f"Model loaded successfully with config: {config}")
break
except Exception as e:
print(f"Failed with config {config}: {e}")
continue
if model is None:
raise RuntimeError("Could not load model with any configuration")
# Move to device if needed
if device.type == 'cuda' and not next(model.parameters()).is_cuda:
model = model.to(device)
return {'model': model, 'processor': processor}
except Exception as e:
print(f"Error in alternative loading: {e}")
raise
# Load models
try:
print("Loading YOLO model...")
yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
print("YOLO model loaded successfully")
print("Loading caption model...")
caption_model_processor = load_caption_model_safe()
print("Caption model loaded successfully")
except Exception as e:
print(f"Critical error loading models: {e}")
print(traceback.format_exc())
caption_model_processor = None
# Don't raise here, let the UI handle it
# Markdown header text
MARKDOWN = """
# OmniParser V2 Pro
AI-powered screen understanding tool that detects UI elements and extracts text with high accuracy.
Supports both PaddleOCR and EasyOCR for flexible text extraction.
"""
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
# Custom CSS for UI enhancement
custom_css = """
body { background-color: #f0f2f5; }
.gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }
h1, h2, h3, h4 { color: #283E51; }
button { border-radius: 6px; transition: all 0.3s ease; }
button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); }
.output-image { border: 2px solid #e1e4e8; border-radius: 8px; }
#input_image { border: 2px dashed #4a90e2; border-radius: 8px; }
#input_image:hover { border-color: #2c5aa0; }
.gr-box { border-radius: 8px; }
.gr-padded { padding: 16px; }
"""
@spaces.GPU
@torch.inference_mode()
def process(
image_input,
box_threshold,
iou_threshold,
use_paddleocr,
imgsz
) -> tuple:
"""Process image with error handling and validation"""
# Input validation
if image_input is None:
return None, "β οΈ Please upload an image for processing."
# Check if caption model is loaded
if caption_model_processor is None:
return None, "β οΈ Caption model not loaded. There was an error during initialization. Please check the logs."
try:
# Log processing parameters
print(f"Processing with parameters: box_threshold={box_threshold}, "
f"iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}")
# Calculate overlay ratio based on input image width
image_width = image_input.size[0]
box_overlay_ratio = max(0.5, min(2.0, image_width / 3200))
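        # Scale annotation sizes against a 3200 px reference width, clamped to
        # [0.5, 2.0] so labels stay legible on very small or very large screenshots.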
draw_bbox_config = {
'text_scale': 0.8 * box_overlay_ratio,
'text_thickness': max(int(2 * box_overlay_ratio), 1),
'text_padding': max(int(3 * box_overlay_ratio), 1),
'thickness': max(int(3 * box_overlay_ratio), 1),
}
# Run OCR bounding box detection
try:
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
image_input,
display_img=False,
output_bb_format='xyxy',
goal_filtering=None,
easyocr_args={'paragraph': False, 'text_threshold': 0.9},
use_paddleocr=use_paddleocr
)
# Handle None result from OCR
if ocr_bbox_rslt is None:
print("OCR returned None, using empty results")
text, ocr_bbox = [], []
else:
text, ocr_bbox = ocr_bbox_rslt
# Validate OCR results
if text is None:
text = []
if ocr_bbox is None:
ocr_bbox = []
print(f"OCR found {len(text)} text regions")
except Exception as e:
print(f"OCR error: {e}, continuing with empty OCR results")
text, ocr_bbox = [], []
# Get labeled image and parsed content
try:
# Ensure the model has the required attribute
if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
model = caption_model_processor['model']
if not hasattr(model, '_supports_sdpa'):
model._supports_sdpa = False
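            # get_som_labeled_img runs YOLO icon detection, merges the boxes with the
            # OCR results, captions each element, and returns a base64-encoded annotated
            # image, box coordinates (normalized, since output_coord_in_ratio=True), and
            # the parsed content list.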
            dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
image_input,
yolo_model,
BOX_TRESHOLD=box_threshold,
output_coord_in_ratio=True,
ocr_bbox=ocr_bbox if ocr_bbox else [],
draw_bbox_config=draw_bbox_config,
caption_model_processor=caption_model_processor,
ocr_text=text if text else [],
iou_threshold=iou_threshold,
imgsz=imgsz
)
            if dino_labeled_img is None:
raise ValueError("Failed to generate labeled image")
except Exception as e:
print(f"Error in SOM processing: {e}")
print(traceback.format_exc())
return image_input, f"β οΈ Error during element detection: {str(e)}"
# Decode processed image from base64
try:
            image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
print('Successfully decoded processed image')
except Exception as e:
print(f"Error decoding image: {e}")
return image_input, f"β οΈ Error decoding processed image: {str(e)}"
# Format parsed content list
if parsed_content_list and len(parsed_content_list) > 0:
parsed_text = "π― **Detected Elements:**\n\n"
for i, v in enumerate(parsed_content_list):
if v: # Only add non-empty content
parsed_text += f"**Icon {i}:** {v}\n"
else:
parsed_text = "βΉοΈ No UI elements detected. Try adjusting the detection thresholds."
        print(f"Finished processing image. Found {len(parsed_content_list or [])} elements.")
return image, parsed_text
except Exception as e:
error_msg = f"β οΈ Unexpected error: {str(e)}"
print(f"Error during processing: {e}")
print(traceback.format_exc())
return None, error_msg
# Build Gradio UI
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro") as demo:
gr.Markdown(MARKDOWN)
# Check if models loaded successfully
if caption_model_processor is None:
gr.Markdown("### β οΈ Warning: Caption model failed to load. Some features may not work.")
with gr.Row():
# Left sidebar: Upload and settings
with gr.Column(scale=1):
with gr.Accordion("π€ Upload Image & Settings", open=True):
image_input_component = gr.Image(
type='pil',
label='Upload Screenshot/UI Image',
elem_id="input_image"
)
gr.Markdown("### ποΈ Detection Settings")
with gr.Group():
box_threshold_component = gr.Slider(
                        label='Box Threshold',
minimum=0.01,
maximum=1.0,
step=0.01,
value=0.05,
info="Lower values detect more elements"
)
iou_threshold_component = gr.Slider(
                        label='IOU Threshold',
minimum=0.01,
maximum=1.0,
step=0.01,
value=0.1,
info="Controls overlap filtering"
)
use_paddleocr_component = gr.Checkbox(
                        label='Use PaddleOCR',
                        value=True,
                        info="Checked = PaddleOCR | Unchecked = EasyOCR"
)
imgsz_component = gr.Slider(
                        label='Detection Image Size',
minimum=640,
maximum=1920,
step=32,
value=640,
info="Higher = better accuracy but slower"
)
                submit_button_component = gr.Button(
                    value='Process Image',
                    variant='primary',
                    size='lg'
                )
                gr.Markdown("### Quick Tips")
                gr.Markdown("""
- **Mobile apps:** Use default settings
- **Desktop apps:** Try image size 1280
- **Complex UIs:** Lower box threshold to 0.03
- **Too many boxes:** Increase IOU threshold
""")
# Right main area: Results tabs
with gr.Column(scale=2):
with gr.Tabs():
with gr.Tab("πΌοΈ Annotated Image"):
image_output_component = gr.Image(
type='pil',
label='Processed Image with Annotations',
elem_classes=["output-image"]
)
with gr.Tab("π Extracted Elements"):
text_output_component = gr.Markdown(
value="*Parsed elements will appear here after processing...*",
elem_classes=["parsed-text"]
)
# Button click event
submit_button_component.click(
fn=process,
inputs=[
image_input_component,
box_threshold_component,
iou_threshold_component,
use_paddleocr_component,
imgsz_component
],
outputs=[image_output_component, text_output_component],
show_progress=True
)
# Launch with queue support
if __name__ == "__main__":
try:
# Set environment variables
os.environ['TRANSFORMERS_OFFLINE'] = '0'
os.environ['HF_HUB_OFFLINE'] = '0'
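        # queue() enables request queuing; max_size caps how many requests may wait
        # before new ones are rejected.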
demo.queue(max_size=10)
demo.launch(
share=False,
show_error=True,
server_name="0.0.0.0",
server_port=7860
)
except Exception as e:
print(f"Failed to launch app: {e}")
print(traceback.format_exc())