from typing import Optional import spaces import gradio as gr import numpy as np import torch from PIL import Image import io import base64, os from huggingface_hub import snapshot_download import traceback import warnings import sys # Suppress specific warnings warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", message=".*_supports_sdpa.*") # CRITICAL: Fix Florence2 model before any imports def fix_florence2_import(): """Pre-patch the Florence2 model class before it's imported""" import importlib.util import types # Create a custom import hook class Florence2ImportHook: def find_spec(self, fullname, path, target=None): if "florence2" in fullname.lower() or "modeling_florence2" in fullname: return importlib.util.spec_from_loader(fullname, Florence2Loader()) return None class Florence2Loader: def create_module(self, spec): return None def exec_module(self, module): # Load the original module import importlib.machinery import importlib.util # Find the actual florence2 module for path in sys.path: florence_path = os.path.join(path, "modeling_florence2.py") if os.path.exists(florence_path): spec = importlib.util.spec_from_file_location("modeling_florence2", florence_path) if spec and spec.loader: spec.loader.exec_module(module) # Patch the module after loading if hasattr(module, 'Florence2ForConditionalGeneration'): original_init = module.Florence2ForConditionalGeneration.__init__ def patched_init(self, config): # Add the missing attribute before calling super().__init__ self._supports_sdpa = False original_init(self, config) module.Florence2ForConditionalGeneration.__init__ = patched_init module.Florence2ForConditionalGeneration._supports_sdpa = False break # Install the import hook hook = Florence2ImportHook() sys.meta_path.insert(0, hook) # Apply the fix before any model imports try: fix_florence2_import() except Exception as e: print(f"Warning: Could not apply import hook: {e}") # Alternative fix: Monkey-patch transformers before importing utils def monkey_patch_transformers(): """Monkey patch transformers to handle _supports_sdpa""" try: import transformers.modeling_utils as modeling_utils original_check = modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation def patched_check(self, *args, **kwargs): # Add the attribute if missing if not hasattr(self, '_supports_sdpa'): self._supports_sdpa = False try: return original_check(self, *args, **kwargs) except AttributeError as e: if '_supports_sdpa' in str(e): # Return a safe default return "eager" raise modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation = patched_check # Also patch the getter original_getattr = modeling_utils.PreTrainedModel.__getattribute__ def patched_getattr(self, name): if name == '_supports_sdpa' and not hasattr(self, '_supports_sdpa'): return False return original_getattr(self, name) modeling_utils.PreTrainedModel.__getattribute__ = patched_getattr print("Successfully patched transformers for Florence2 compatibility") except Exception as e: print(f"Warning: Could not patch transformers: {e}") # Apply the monkey patch monkey_patch_transformers() # Now import the utils after patching from util.utils import check_ocr_box, get_yolo_model, get_som_labeled_img # Download repository (if not already downloaded) repo_id = "microsoft/OmniParser-v2.0" local_dir = "weights" if not os.path.exists(local_dir): snapshot_download(repo_id=repo_id, local_dir=local_dir) print(f"Repository downloaded to: {local_dir}") else: print(f"Weights already exist at: {local_dir}") # Custom function to load caption model with proper error handling def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"): """Safely load caption model with multiple fallback methods""" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') try: # Method 1: Try the original function with patching from util.utils import get_caption_model_processor return get_caption_model_processor(model_name, model_name_or_path) except AttributeError as e: if '_supports_sdpa' in str(e): print(f"SDPA error detected, trying alternative loading method...") else: raise # Method 2: Load directly with specific configuration try: from transformers import AutoProcessor, AutoModelForCausalLM print(f"Loading caption model from {model_name_or_path} with alternative method...") # Load processor processor = AutoProcessor.from_pretrained( model_name_or_path, trust_remote_code=True, revision="main" ) # Try to load model with different configurations configs_to_try = [ {"attn_implementation": "eager", "use_cache": False}, {"use_flash_attention_2": False, "use_cache": False}, {"torch_dtype": torch.float32}, # Try float32 instead of float16 ] model = None for config in configs_to_try: try: model = AutoModelForCausalLM.from_pretrained( model_name_or_path, trust_remote_code=True, device_map="auto" if torch.cuda.is_available() else None, **config ) # Ensure the attribute exists if not hasattr(model, '_supports_sdpa'): model._supports_sdpa = False print(f"Model loaded successfully with config: {config}") break except Exception as e: print(f"Failed with config {config}: {e}") continue if model is None: raise RuntimeError("Could not load model with any configuration") # Move to device if needed if device.type == 'cuda' and not next(model.parameters()).is_cuda: model = model.to(device) return {'model': model, 'processor': processor} except Exception as e: print(f"Error in alternative loading: {e}") raise # Load models try: print("Loading YOLO model...") yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt') print("YOLO model loaded successfully") print("Loading caption model...") caption_model_processor = load_caption_model_safe() print("Caption model loaded successfully") except Exception as e: print(f"Critical error loading models: {e}") print(traceback.format_exc()) caption_model_processor = None # Don't raise here, let the UI handle it # Markdown header text MARKDOWN = """ # OmniParser V2 ProπŸ”₯

🎯 AI-powered screen understanding tool that detects UI elements and extracts text with high accuracy.

πŸ“ Supports both PaddleOCR and EasyOCR for flexible text extraction.

""" DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {DEVICE}") # Custom CSS for UI enhancement custom_css = """ body { background-color: #f0f2f5; } .gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; } h1, h2, h3, h4 { color: #283E51; } button { border-radius: 6px; transition: all 0.3s ease; } button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); } .output-image { border: 2px solid #e1e4e8; border-radius: 8px; } #input_image { border: 2px dashed #4a90e2; border-radius: 8px; } #input_image:hover { border-color: #2c5aa0; } .gr-box { border-radius: 8px; } .gr-padded { padding: 16px; } """ @spaces.GPU @torch.inference_mode() def process( image_input, box_threshold, iou_threshold, use_paddleocr, imgsz ) -> tuple: """Process image with error handling and validation""" # Input validation if image_input is None: return None, "⚠️ Please upload an image for processing." # Check if caption model is loaded if caption_model_processor is None: return None, "⚠️ Caption model not loaded. There was an error during initialization. Please check the logs." try: # Log processing parameters print(f"Processing with parameters: box_threshold={box_threshold}, " f"iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}") # Calculate overlay ratio based on input image width image_width = image_input.size[0] box_overlay_ratio = max(0.5, min(2.0, image_width / 3200)) draw_bbox_config = { 'text_scale': 0.8 * box_overlay_ratio, 'text_thickness': max(int(2 * box_overlay_ratio), 1), 'text_padding': max(int(3 * box_overlay_ratio), 1), 'thickness': max(int(3 * box_overlay_ratio), 1), } # Run OCR bounding box detection try: ocr_bbox_rslt, is_goal_filtered = check_ocr_box( image_input, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9}, use_paddleocr=use_paddleocr ) # Handle None result from OCR if ocr_bbox_rslt is None: print("OCR returned None, using empty results") text, ocr_bbox = [], [] else: text, ocr_bbox = ocr_bbox_rslt # Validate OCR results if text is None: text = [] if ocr_bbox is None: ocr_bbox = [] print(f"OCR found {len(text)} text regions") except Exception as e: print(f"OCR error: {e}, continuing with empty OCR results") text, ocr_bbox = [], [] # Get labeled image and parsed content try: # Ensure the model has the required attribute if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor: model = caption_model_processor['model'] if not hasattr(model, '_supports_sdpa'): model._supports_sdpa = False dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img( image_input, yolo_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox if ocr_bbox else [], draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text if text else [], iou_threshold=iou_threshold, imgsz=imgsz ) if dino_labled_img is None: raise ValueError("Failed to generate labeled image") except Exception as e: print(f"Error in SOM processing: {e}") print(traceback.format_exc()) return image_input, f"⚠️ Error during element detection: {str(e)}" # Decode processed image from base64 try: image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img))) print('Successfully decoded processed image') except Exception as e: print(f"Error decoding image: {e}") return image_input, f"⚠️ Error decoding processed image: {str(e)}" # Format parsed content list if parsed_content_list and len(parsed_content_list) > 0: parsed_text = "🎯 **Detected Elements:**\n\n" for i, v in enumerate(parsed_content_list): if v: # Only add non-empty content parsed_text += f"**Icon {i}:** {v}\n" else: parsed_text = "ℹ️ No UI elements detected. Try adjusting the detection thresholds." print(f'Finished processing image. Found {len(parsed_content_list)} elements.') return image, parsed_text except Exception as e: error_msg = f"⚠️ Unexpected error: {str(e)}" print(f"Error during processing: {e}") print(traceback.format_exc()) return None, error_msg # Build Gradio UI with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro") as demo: gr.Markdown(MARKDOWN) # Check if models loaded successfully if caption_model_processor is None: gr.Markdown("### ⚠️ Warning: Caption model failed to load. Some features may not work.") with gr.Row(): # Left sidebar: Upload and settings with gr.Column(scale=1): with gr.Accordion("πŸ“€ Upload Image & Settings", open=True): image_input_component = gr.Image( type='pil', label='Upload Screenshot/UI Image', elem_id="input_image" ) gr.Markdown("### πŸŽ›οΈ Detection Settings") with gr.Group(): box_threshold_component = gr.Slider( label='πŸ“Š Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05, info="Lower values detect more elements" ) iou_threshold_component = gr.Slider( label='πŸ”² IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1, info="Controls overlap filtering" ) use_paddleocr_component = gr.Checkbox( label='πŸ”€ Use PaddleOCR', value=True, info="βœ“ PaddleOCR | βœ— EasyOCR" ) imgsz_component = gr.Slider( label='πŸ“ Detection Image Size', minimum=640, maximum=1920, step=32, value=640, info="Higher = better accuracy but slower" ) submit_button_component = gr.Button( value='πŸš€ Process Image', variant='primary', size='lg' ) gr.Markdown("### πŸ’‘ Quick Tips") gr.Markdown(""" - **Mobile apps:** Use default settings - **Desktop apps:** Try image size 1280 - **Complex UIs:** Lower box threshold to 0.03 - **Too many boxes:** Increase IOU threshold """) # Right main area: Results tabs with gr.Column(scale=2): with gr.Tabs(): with gr.Tab("πŸ–ΌοΈ Annotated Image"): image_output_component = gr.Image( type='pil', label='Processed Image with Annotations', elem_classes=["output-image"] ) with gr.Tab("πŸ“ Extracted Elements"): text_output_component = gr.Markdown( value="*Parsed elements will appear here after processing...*", elem_classes=["parsed-text"] ) # Button click event submit_button_component.click( fn=process, inputs=[ image_input_component, box_threshold_component, iou_threshold_component, use_paddleocr_component, imgsz_component ], outputs=[image_output_component, text_output_component], show_progress=True ) # Launch with queue support if __name__ == "__main__": try: # Set environment variables os.environ['TRANSFORMERS_OFFLINE'] = '0' os.environ['HF_HUB_OFFLINE'] = '0' demo.queue(max_size=10) demo.launch( share=False, show_error=True, server_name="0.0.0.0", server_port=7860 ) except Exception as e: print(f"Failed to launch app: {e}") print(traceback.format_exc())