from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import io
import base64

app = FastAPI(
    title="OmniParser API",
    description="API for parsing GUI elements from images",
    version="1.0.0"
)


# Model wrapper
class OmniParser:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The Florence-2 processor handles both image preprocessing and
        # prompt tokenization
        self.processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base",
            trust_remote_code=True,
            cache_dir="/code/.cache"
        )
        # Florence-2-based checkpoints load through AutoModelForCausalLM;
        # the icon-caption weights sit in a subfolder of the OmniParser repo,
        # so they are addressed via subfolder= rather than a three-segment
        # repo id (which from_pretrained rejects)
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/OmniParser",
            subfolder="icon_caption_florence",
            trust_remote_code=True,
            cache_dir="/code/.cache"
        ).to(self.device)

    @torch.inference_mode()
    def process_image(
        self,
        image: Image.Image,
        question: str = "What elements do you see in this GUI?",
    ) -> Dict[str, Any]:
        # Preprocess the image and tokenize the prompt in one call
        inputs = self.processor(
            images=image, text=question, return_tensors="pt"
        ).to(self.device)
        # Florence-2 is generative: answers come from model.generate(),
        # not from an argmax over a single forward pass's logits
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=256,
        )
        predicted_answer = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        return {
            "parsed_elements": predicted_answer,
            "box_coordinates": {}  # Placeholder for future box detection implementation
        }


# Initialize model once at import time so it is shared across requests
model = OmniParser()


# Request/Response models
class ParseRequest(BaseModel):
    image_data: str = Field(..., description="Base64 encoded image data")
    question: Optional[str] = Field(
        default="What elements do you see in this GUI?",
        description="Question to ask about the GUI"
    )


class ParseResponse(BaseModel):
    parsed_elements: str
    box_coordinates: dict
    output_image: Optional[str]


def load_and_preprocess_image(image_data: bytes) -> Image.Image:
    """Load an image from bytes and convert it to RGB."""
    try:
        image = Image.open(io.BytesIO(image_data))
        return image.convert("RGB")  # model expects 3-channel input
    except Exception as e:
        raise ValueError(f"Failed to load image: {str(e)}")


def encode_output_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


@app.get("/")
async def root():
    return {
        "message": "OmniParser API is running",
        "docs_url": "/docs"
    }


@app.post("/parse", response_model=ParseResponse)
async def parse_image(request: ParseRequest):
    try:
        # Decode base64 image
        image_bytes = base64.b64decode(request.image_data)
        image = load_and_preprocess_image(image_bytes)

        # Process with model
        result = model.process_image(
            image=image,
            question=request.question
        )

        # Prepare response
        return ParseResponse(
            parsed_elements=result["parsed_elements"],
            box_coordinates=result["box_coordinates"],
            output_image=encode_output_image(image)
        )
    except ValueError as e:
        # Invalid base64 or an unreadable image is a client error, not a 500
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
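
# --- Example client (illustrative sketch) ---
# A minimal way to exercise the /parse endpoint, assuming the server is
# running locally on port 8000. The `requests` dependency and the
# "screenshot.png" path are assumptions for illustration, not part of
# the API above.
#
#   import base64
#   import requests
#
#   with open("screenshot.png", "rb") as f:
#       payload = {"image_data": base64.b64encode(f.read()).decode()}
#   resp = requests.post("http://localhost:8000/parse", json=payload)
#   resp.raise_for_status()
#   print(resp.json()["parsed_elements"])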